From 33a8ad6bdb751850bdda6af7e4402faf07156326 Mon Sep 17 00:00:00 2001 From: Saurabh Kumar Bajpai <157192462+saurabhhhcodes@users.noreply.github.com> Date: Thu, 28 May 2026 16:37:26 +0530 Subject: [PATCH] feat: add smart model routing strategy --- README.md | 11 +- package.json | 2 + src/index.ts | 4 +- src/router/modelRouter.ts | 103 +++++++++++++++- src/router/taskClassifier.ts | 228 +++++++++++++++++++++++++++++++++++ src/router/types.ts | 40 +++++- tests/smart-router.test.js | 147 ++++++++++++++++++++++ 7 files changed, 531 insertions(+), 4 deletions(-) create mode 100644 src/router/taskClassifier.ts create mode 100644 tests/smart-router.test.js diff --git a/README.md b/README.md index 57253f1..15eea82 100644 --- a/README.md +++ b/README.md @@ -283,8 +283,12 @@ Creates and configures an intelligent model router. ```typescript interface ModelRouterOptions { - strategy: "fallback" | "context" | "cost"; // Routing strategy + strategy: "fallback" | "context" | "cost" | "smart"; // Routing strategy fallbackMap?: Record; // Fallback model map + taskClassification?: Record; + modelOverrides?: Record; + confidenceThreshold?: number; // Smart routing threshold (default: 0.7) + defaultModel?: string; // Smart routing fallback model maxRetries?: number; // Max retry attempts (default: 1) } ``` @@ -324,6 +328,11 @@ patchGlobalFetch(); - Selects cheaper model from same provider - Best for: Cost optimization, rate limit handling +**4. Smart Strategy** - Selects a model from task classification +- Detects task intent from prompt/message content +- Supports custom task rules, model overrides, confidence threshold, and default model +- Best for: Task-aware model quality and cost optimization + ### Error Detection The router automatically detects and classifies failures: diff --git a/package.json b/package.json index 6f80c8a..f6084cc 100644 --- a/package.json +++ b/package.json @@ -6,6 +6,8 @@ "types": "dist/index.d.ts", "scripts": { "build": "tsc", + "test": "npm run build && npm run test:smart", + "test:smart": "node tests/smart-router.test.js", "prepublishOnly": "npm run build" }, "keywords": [ diff --git a/src/index.ts b/src/index.ts index c769e99..adb84c1 100644 --- a/src/index.ts +++ b/src/index.ts @@ -249,7 +249,9 @@ export type { FailureContext, RoutingDecision, RouterEvent, - ApiKeyConfig + ApiKeyConfig, + TaskClassification, + TaskClassificationRule } from "./router/types"; /** diff --git a/src/router/modelRouter.ts b/src/router/modelRouter.ts index 4b6f8ea..de59baa 100644 --- a/src/router/modelRouter.ts +++ b/src/router/modelRouter.ts @@ -2,11 +2,14 @@ import { ModelRouterOptions, FailureContext, RoutingDecision, - RoutingStrategy + RoutingStrategy, + TaskClassificationRule } from "./types"; import { errorDetector } from "./errorDetector"; import { fallbackStrategy, contextStrategy, costStrategy } from "./routingStrategies"; import { apiKeyManager } from "./apiKeyManager"; +import { detectProvider } from "./providerDetector"; +import { TaskClassifier } from "./taskClassifier"; /** * Intelligent Model Router @@ -17,12 +20,26 @@ export class ModelRouter { private fallbackMap: Record; private maxRetries: number; private crossProviderEnabled: boolean; + private taskClassifier: TaskClassifier | null; + private taskClassification?: Record; + private confidenceThreshold: number; + private defaultModel?: string; constructor(options: ModelRouterOptions) { this.strategy = options.strategy; this.fallbackMap = options.fallbackMap || {}; this.maxRetries = options.maxRetries ?? 1; this.crossProviderEnabled = options.enableCrossProvider ?? false; + this.taskClassification = options.taskClassification; + this.confidenceThreshold = options.confidenceThreshold ?? 0.7; + this.defaultModel = options.defaultModel; + this.taskClassifier = + options.strategy === "smart" + ? new TaskClassifier( + options.taskClassification, + options.modelOverrides + ) + : null; // Register API keys if provided if (options.apiKeys) { @@ -46,6 +63,25 @@ export class ModelRouter { ); } + if (this.confidenceThreshold < 0 || this.confidenceThreshold > 1) { + throw new Error( + "TokenFirewall Router: confidenceThreshold must be between 0 and 1" + ); + } + + if ( + this.defaultModel !== undefined && + (typeof this.defaultModel !== "string" || this.defaultModel.trim() === "") + ) { + throw new Error( + "TokenFirewall Router: defaultModel must be a non-empty string when provided" + ); + } + + if (this.strategy === "smart") { + this.validateTaskClassification(); + } + if (this.strategy === "fallback") { if (Object.keys(this.fallbackMap).length === 0) { throw new Error( @@ -137,6 +173,9 @@ export class ModelRouter { case "cost": return costStrategy(context, failureType as any); + case "smart": + return this.smartStrategy(context, failureType); + default: throw new Error( `TokenFirewall Router: Unknown strategy "${this.strategy}"` @@ -164,4 +203,66 @@ export class ModelRouter { public isCrossProviderEnabled(): boolean { return this.crossProviderEnabled; } + + /** + * Validate custom smart-routing rules. + */ + private validateTaskClassification(): void { + if (!this.taskClassification) { + return; + } + + for (const [taskType, rule] of Object.entries(this.taskClassification)) { + if (!rule || typeof rule !== "object") { + throw new Error( + `TokenFirewall Router: smart task "${taskType}" must be an object` + ); + } + } + } + + /** + * Task-aware model selection for smart routing. + */ + private smartStrategy( + context: FailureContext, + failureType: string + ): RoutingDecision { + const classification = this.taskClassifier?.classify(context.requestBody); + const selectedModel = + classification && classification.confidence >= this.confidenceThreshold + ? classification.selectedModel + : this.defaultModel; + + if (!selectedModel) { + const confidence = classification + ? ` (confidence ${classification.confidence.toFixed(2)})` + : ""; + return { + retry: false, + reason: `Smart strategy could not classify request above threshold${confidence}` + }; + } + + const selectedProvider = detectProvider(selectedModel); + if ( + selectedProvider && + selectedProvider !== context.provider && + !this.crossProviderEnabled + ) { + return { + retry: false, + reason: + `Smart strategy selected ${selectedModel}, but cross-provider routing is disabled` + }; + } + + return { + retry: true, + nextModel: selectedModel, + reason: classification + ? `${classification.reason} after ${failureType}` + : `Using default smart-routing model after ${failureType}` + }; + } } diff --git a/src/router/taskClassifier.ts b/src/router/taskClassifier.ts new file mode 100644 index 0000000..c29e693 --- /dev/null +++ b/src/router/taskClassifier.ts @@ -0,0 +1,228 @@ +import { TaskClassification, TaskClassificationRule } from "./types"; + +interface ScoredRule { + taskType: string; + rule: Required> & + Omit; +} + +const DEFAULT_RULES: Record = { + code_generation: { + model: "claude-3-5-sonnet-20241022", + reason: "Code generation task detected", + keywords: [ + "write code", + "create function", + "implement", + "build", + "develop", + "program" + ], + patterns: [/write.*code/i, /create.*function/i, /implement.*class/i], + priority: 10 + }, + code_review: { + model: "claude-3-5-sonnet-20241022", + reason: "Code review or refactoring task detected", + keywords: ["review code", "find bugs", "refactor", "debug", "optimize"], + patterns: [/review.*code/i, /find.*bug/i, /refactor/i, /debug/i], + priority: 9 + }, + math_reasoning: { + model: "o1-mini", + reason: "Math or reasoning task detected", + keywords: ["calculate", "solve", "equation", "formula", "math"], + patterns: [/solve.*equation/i, /calculate/i, /mathematical/i], + priority: 8 + }, + document_analysis: { + model: "gemini-2.5-pro", + reason: "Long document or summarization task detected", + keywords: ["summarize document", "analyze document", "extract from"], + patterns: [/summarize.*document/i, /analyze.*pdf/i, /extract.*information/i], + priority: 7 + }, + creative_writing: { + model: "gpt-4o", + reason: "Creative writing task detected", + keywords: ["write story", "blog post", "article", "creative"], + patterns: [/write.*story/i, /creative.*writing/i, /blog.*post/i], + priority: 6 + }, + translation: { + model: "gpt-4o-mini", + reason: "Translation task detected", + keywords: ["translate", "translation", "convert to"], + patterns: [/translate.*to/i, /translation/i], + priority: 5 + }, + simple_chat: { + model: "gpt-4o-mini", + reason: "Simple chat task detected", + keywords: ["hello", "hi", "thanks", "thank you", "help"], + patterns: [/^(hi|hello|hey)\b/i, /how.*are.*you/i, /thank/i], + priority: 1 + } +}; + +/** + * Lightweight rule-based classifier for smart model routing. + */ +export class TaskClassifier { + private rules: ScoredRule[]; + private modelOverrides: Record; + + constructor( + taskClassification: Record = {}, + modelOverrides: Record = {} + ) { + const mergedRules = { + ...DEFAULT_RULES, + ...taskClassification + }; + + this.rules = Object.entries(mergedRules) + .map(([taskType, rule]) => this.normalizeRule(taskType, rule)) + .sort((a, b) => (b.rule.priority ?? 0) - (a.rule.priority ?? 0)); + + this.modelOverrides = modelOverrides; + } + + /** + * Classify a provider request body and select the best task model. + */ + public classify(requestBody: unknown): TaskClassification | null { + const prompt = this.extractPrompt(requestBody).trim(); + + if (!prompt) { + return null; + } + + const normalizedPrompt = prompt.toLowerCase(); + let best: TaskClassification | null = null; + let bestScore = 0; + + for (const { taskType, rule } of this.rules) { + const keywordMatches = (rule.keywords ?? []).filter(keyword => + normalizedPrompt.includes(keyword.toLowerCase()) + ).length; + + const patternMatches = (rule.patterns ?? []).filter(pattern => + pattern.test(prompt) + ).length; + + if (keywordMatches === 0 && patternMatches === 0) { + continue; + } + + const score = keywordMatches + patternMatches * 2 + (rule.priority ?? 0) / 100; + const confidence = Math.min(0.99, 0.6 + keywordMatches * 0.12 + patternMatches * 0.18); + + if (score > bestScore) { + bestScore = score; + best = { + taskType, + confidence, + selectedModel: this.modelOverrides[taskType] ?? rule.model, + reason: rule.reason ?? `Matched smart-routing task "${taskType}"` + }; + } + } + + return best; + } + + private normalizeRule(taskType: string, rule: TaskClassificationRule): ScoredRule { + if (!rule.model || typeof rule.model !== "string" || rule.model.trim() === "") { + throw new Error( + `TokenFirewall Router: smart task "${taskType}" requires a non-empty model` + ); + } + + return { + taskType, + rule: { + ...rule, + model: rule.model + } + }; + } + + private extractPrompt(value: unknown): string { + if (!value || typeof value !== "object") { + return ""; + } + + const body = value as Record; + const directPrompt = this.firstString(body.prompt, body.input, body.text); + + if (directPrompt) { + return directPrompt; + } + + if (Array.isArray(body.messages)) { + return body.messages + .map(message => { + if (!message || typeof message !== "object") { + return ""; + } + return this.extractContent((message as Record).content); + }) + .filter(Boolean) + .join("\n"); + } + + if (Array.isArray(body.contents)) { + return body.contents + .map(content => { + if (!content || typeof content !== "object") { + return ""; + } + const parts = (content as Record).parts; + if (!Array.isArray(parts)) { + return ""; + } + return parts + .map(part => + part && typeof part === "object" + ? this.firstString((part as Record).text) + : "" + ) + .filter(Boolean) + .join("\n"); + }) + .filter(Boolean) + .join("\n"); + } + + return ""; + } + + private extractContent(content: unknown): string { + if (typeof content === "string") { + return content; + } + + if (!Array.isArray(content)) { + return ""; + } + + return content + .map(part => + part && typeof part === "object" + ? this.firstString((part as Record).text) + : "" + ) + .filter(Boolean) + .join("\n"); + } + + private firstString(...values: unknown[]): string { + for (const value of values) { + if (typeof value === "string" && value.trim() !== "") { + return value; + } + } + return ""; + } +} diff --git a/src/router/types.ts b/src/router/types.ts index 8e333f4..70ff58c 100644 --- a/src/router/types.ts +++ b/src/router/types.ts @@ -5,7 +5,7 @@ /** * Routing strategy types */ -export type RoutingStrategy = "fallback" | "context" | "cost"; +export type RoutingStrategy = "fallback" | "context" | "cost" | "smart"; /** * Failure types detected by error detector @@ -37,6 +37,14 @@ export interface ModelRouterOptions { strategy: RoutingStrategy; /** Map of primary models to fallback models */ fallbackMap?: Record; + /** Optional task classification rules for smart routing */ + taskClassification?: Record; + /** Override the model selected for a task type */ + modelOverrides?: Record; + /** Minimum confidence required for smart routing, from 0 to 1 (default: 0.7) */ + confidenceThreshold?: number; + /** Fallback model when no task-specific model can be selected */ + defaultModel?: string; /** Maximum number of retry attempts (default: 1) */ maxRetries?: number; /** API keys for cross-provider fallback */ @@ -45,6 +53,36 @@ export interface ModelRouterOptions { enableCrossProvider?: boolean; } +/** + * Smart-routing task configuration + */ +export interface TaskClassificationRule { + /** Preferred model for the task */ + model: string; + /** Human-readable selection reason */ + reason?: string; + /** Plain-text phrases that identify the task */ + keywords?: string[]; + /** Regex patterns that identify the task */ + patterns?: RegExp[]; + /** Higher-priority rules win ties */ + priority?: number; +} + +/** + * Result from the task classifier + */ +export interface TaskClassification { + /** Matched task type */ + taskType: string; + /** Confidence score from 0 to 1 */ + confidence: number; + /** Selected model for the task */ + selectedModel: string; + /** Reason for the classification */ + reason: string; +} + /** * Context information about a failed request */ diff --git a/tests/smart-router.test.js b/tests/smart-router.test.js new file mode 100644 index 0000000..991cd35 --- /dev/null +++ b/tests/smart-router.test.js @@ -0,0 +1,147 @@ +/** + * Smart router tests. + * + * Run after building TypeScript: + * npm run build && node tests/smart-router.test.js + */ + +const assert = require("assert"); +const { createModelRouter, disableModelRouter } = require("../dist/index.js"); + +function failureContext(requestBody, overrides = {}) { + return { + error: { status: 429 }, + originalModel: "gpt-4o", + requestBody, + provider: "openai", + retryCount: 0, + attemptedModels: ["gpt-4o"], + ...overrides + }; +} + +function test(name, fn) { + try { + fn(); + console.log(`ok - ${name}`); + } catch (error) { + console.error(`not ok - ${name}`); + throw error; + } finally { + disableModelRouter(); + } +} + +test("smart strategy routes code tasks to the classifier-selected model", () => { + const router = createModelRouter({ + strategy: "smart", + enableCrossProvider: true, + maxRetries: 1 + }); + + const decision = router.handleFailure( + failureContext({ + messages: [ + { + role: "user", + content: "Write code for a TypeScript rate limiter" + } + ] + }) + ); + + assert.strictEqual(decision.retry, true); + assert.strictEqual(decision.nextModel, "claude-3-5-sonnet-20241022"); + assert.match(decision.reason, /Code generation/); +}); + +test("smart strategy blocks cross-provider routing unless enabled", () => { + const router = createModelRouter({ + strategy: "smart", + maxRetries: 1 + }); + + const decision = router.handleFailure( + failureContext({ + messages: [ + { + role: "user", + content: "Implement a JavaScript helper function" + } + ] + }) + ); + + assert.strictEqual(decision.retry, false); + assert.match(decision.reason, /cross-provider routing is disabled/); +}); + +test("smart strategy uses default model when confidence is below threshold", () => { + const router = createModelRouter({ + strategy: "smart", + confidenceThreshold: 0.99, + defaultModel: "gpt-4o-mini", + maxRetries: 1 + }); + + const decision = router.handleFailure( + failureContext({ + messages: [ + { + role: "user", + content: "This message has no strong routing signal" + } + ] + }) + ); + + assert.strictEqual(decision.retry, true); + assert.strictEqual(decision.nextModel, "gpt-4o-mini"); +}); + +test("smart strategy validates confidence threshold", () => { + assert.throws( + () => + createModelRouter({ + strategy: "smart", + confidenceThreshold: 2 + }), + /confidenceThreshold must be between 0 and 1/ + ); +}); + +test("smart strategy supports custom task rules and model overrides", () => { + const router = createModelRouter({ + strategy: "smart", + enableCrossProvider: true, + taskClassification: { + sql_generation: { + model: "gpt-4o", + reason: "SQL task detected", + keywords: ["write sql"], + priority: 20 + } + }, + modelOverrides: { + sql_generation: "claude-3-5-sonnet-20241022" + }, + maxRetries: 1 + }); + + const decision = router.handleFailure( + failureContext({ + messages: [ + { + role: "user", + content: "Write SQL to aggregate monthly revenue by customer" + } + ] + }) + ); + + assert.strictEqual(decision.retry, true); + assert.strictEqual(decision.nextModel, "claude-3-5-sonnet-20241022"); + assert.match(decision.reason, /SQL task detected/); +}); + +console.log("smart router tests passed");