diff --git a/src/router/modelRouter.ts b/src/router/modelRouter.ts index 4b6f8ea..3a5156b 100644 --- a/src/router/modelRouter.ts +++ b/src/router/modelRouter.ts @@ -2,10 +2,16 @@ import { ModelRouterOptions, FailureContext, RoutingDecision, - RoutingStrategy + RoutingStrategy, + SmartRoutingOptions } from "./types"; import { errorDetector } from "./errorDetector"; -import { fallbackStrategy, contextStrategy, costStrategy } from "./routingStrategies"; +import { + fallbackStrategy, + contextStrategy, + costStrategy, + smartStrategy +} from "./routingStrategies"; import { apiKeyManager } from "./apiKeyManager"; /** @@ -17,12 +23,14 @@ export class ModelRouter { private fallbackMap: Record; private maxRetries: number; private crossProviderEnabled: boolean; + private smartRouting: SmartRoutingOptions; constructor(options: ModelRouterOptions) { this.strategy = options.strategy; this.fallbackMap = options.fallbackMap || {}; this.maxRetries = options.maxRetries ?? 1; this.crossProviderEnabled = options.enableCrossProvider ?? false; + this.smartRouting = options.smartRouting || {}; // Register API keys if provided if (options.apiKeys) { @@ -72,6 +80,57 @@ export class ModelRouter { } } } + + if (this.strategy === "smart") { + const threshold = this.smartRouting.confidenceThreshold; + if ( + threshold !== undefined && + (typeof threshold !== "number" || threshold < 0 || threshold > 1) + ) { + throw new Error( + "TokenFirewall Router: smartRouting.confidenceThreshold must be between 0 and 1" + ); + } + + this.validateModelMap("smartRouting.taskModelMap", this.smartRouting.taskModelMap); + this.validateModelList("smartRouting.fallbackModels", this.smartRouting.fallbackModels); + } + } + + private validateModelMap(name: string, value?: Record): void { + if (value === undefined) { + return; + } + + if (!value || typeof value !== "object" || Array.isArray(value)) { + throw new Error(`TokenFirewall Router: ${name} must be an object`); + } + + for (const [task, model] of Object.entries(value)) { + if (!task.trim() || typeof model !== "string" || !model.trim()) { + throw new Error( + `TokenFirewall Router: ${name} entries must map non-empty task names to model names` + ); + } + } + } + + private validateModelList(name: string, value?: string[]): void { + if (value === undefined) { + return; + } + + if (!Array.isArray(value)) { + throw new Error(`TokenFirewall Router: ${name} must be an array`); + } + + for (const model of value) { + if (typeof model !== "string" || !model.trim()) { + throw new Error( + `TokenFirewall Router: ${name} must only contain non-empty model names` + ); + } + } } /** @@ -137,6 +196,9 @@ export class ModelRouter { case "cost": return costStrategy(context, failureType as any); + case "smart": + return smartStrategy(context, failureType as any, this.smartRouting); + default: throw new Error( `TokenFirewall Router: Unknown strategy "${this.strategy}"` diff --git a/src/router/routingStrategies.ts b/src/router/routingStrategies.ts index ca5533a..c7bea8d 100644 --- a/src/router/routingStrategies.ts +++ b/src/router/routingStrategies.ts @@ -1,4 +1,9 @@ -import { FailureContext, RoutingDecision, FailureType } from "./types"; +import { + FailureContext, + RoutingDecision, + FailureType, + SmartRoutingOptions +} from "./types"; import { contextRegistry } from "../introspection/contextRegistry"; import { pricingRegistry } from "../core/pricingRegistry"; @@ -185,6 +190,185 @@ export function costStrategy( }; } +const DEFAULT_SMART_TASK_MODELS: Record = { + code: "gpt-4.1", + analysis: "gpt-4o", + math: "o1-mini", + summarization: "gpt-4o-mini", + chat: "gpt-4o-mini" +}; + +const SMART_TASK_KEYWORDS: Record = { + code: [ + "bug", + "code", + "debug", + "function", + "refactor", + "stack trace", + "typescript", + "unit test" + ], + analysis: [ + "analyze", + "compare", + "evaluate", + "explain", + "insight", + "recommend", + "tradeoff" + ], + math: [ + "calculate", + "equation", + "math", + "probability", + "proof", + "solve", + "statistics" + ], + summarization: [ + "brief", + "condense", + "notes", + "recap", + "summarize", + "summary", + "tl;dr" + ], + chat: [ + "chat", + "conversation", + "friendly", + "reply", + "rewrite", + "tone" + ] +}; + +/** + * Smart routing strategy + * Classifies the request body and selects a task-specific model. + */ +export function smartStrategy( + context: FailureContext, + failureType: FailureType, + options: SmartRoutingOptions = {} +): RoutingDecision { + const confidenceThreshold = options.confidenceThreshold ?? 0.35; + const taskModelMap = { + ...DEFAULT_SMART_TASK_MODELS, + ...(options.taskModelMap || {}) + }; + + const classification = classifyRequestTask(context.requestBody); + + if (classification.confidence >= confidenceThreshold) { + const nextModel = taskModelMap[classification.task]; + + if (nextModel && !context.attemptedModels.includes(nextModel)) { + return { + retry: true, + nextModel, + reason: + `Smart routing selected ${classification.task} model ` + + `after ${failureType} (confidence ${classification.confidence.toFixed(2)})` + }; + } + } + + const fallbackModel = (options.fallbackModels || []) + .find(model => model !== context.originalModel && !context.attemptedModels.includes(model)); + + if (fallbackModel) { + return { + retry: true, + nextModel: fallbackModel, + reason: + `Smart routing used fallback model after ${failureType}; ` + + `task confidence was ${classification.confidence.toFixed(2)}` + }; + } + + return { + retry: false, + reason: + `Smart routing could not find an eligible model ` + + `(task=${classification.task}, confidence=${classification.confidence.toFixed(2)})` + }; +} + +function classifyRequestTask(requestBody: unknown): { task: string; confidence: number } { + const text = extractPromptText(requestBody).toLowerCase(); + + if (!text) { + return { task: "chat", confidence: 0 }; + } + + const scores = Object.entries(SMART_TASK_KEYWORDS) + .map(([task, keywords]) => ({ + task, + score: keywords.reduce((total, keyword) => { + return total + (text.includes(keyword) ? 1 : 0); + }, 0), + keywordCount: keywords.length + })) + .sort((a, b) => b.score - a.score); + + const best = scores[0]; + if (!best || best.score === 0) { + return { task: "chat", confidence: 0.2 }; + } + + return { + task: best.task, + confidence: Math.min(0.95, 0.25 + best.score / Math.max(best.keywordCount, 1)) + }; +} + +function extractPromptText(value: unknown): string { + if (!value) { + return ""; + } + + if (typeof value === "string") { + return value; + } + + if (Array.isArray(value)) { + return value.map(extractPromptText).filter(Boolean).join(" "); + } + + if (typeof value !== "object") { + return ""; + } + + const record = value as Record; + const direct = [ + record.prompt, + record.input, + record.query, + record.text, + record.content + ] + .map(extractPromptText) + .filter(Boolean); + + const messages = Array.isArray(record.messages) + ? record.messages.map(message => { + if (typeof message === "string") { + return message; + } + if (message && typeof message === "object") { + return extractPromptText((message as Record).content); + } + return ""; + }) + : []; + + return [...direct, ...messages].filter(Boolean).join(" "); +} + /** * Helper to get known models for a provider * Uses context registry for dynamic model discovery, falls back to static list diff --git a/src/router/types.ts b/src/router/types.ts index 8e333f4..1795a9a 100644 --- a/src/router/types.ts +++ b/src/router/types.ts @@ -5,7 +5,7 @@ /** * Routing strategy types */ -export type RoutingStrategy = "fallback" | "context" | "cost"; +export type RoutingStrategy = "fallback" | "context" | "cost" | "smart"; /** * Failure types detected by error detector @@ -29,6 +29,18 @@ export interface ApiKeyConfig { [key: string]: string | undefined; } +/** + * Configuration for the smart routing strategy. + */ +export interface SmartRoutingOptions { + /** Minimum classifier confidence required before choosing a task-specific model */ + confidenceThreshold?: number; + /** Model to use per detected task type */ + taskModelMap?: Record; + /** Fallback models to try when confidence is low or the task-specific model was attempted */ + fallbackModels?: string[]; +} + /** * Configuration options for model router */ @@ -43,6 +55,8 @@ export interface ModelRouterOptions { apiKeys?: ApiKeyConfig; /** Enable cross-provider fallback (default: false) */ enableCrossProvider?: boolean; + /** Smart routing configuration */ + smartRouting?: SmartRoutingOptions; } /** diff --git a/tests/smart-strategy.test.js b/tests/smart-strategy.test.js new file mode 100644 index 0000000..682197c --- /dev/null +++ b/tests/smart-strategy.test.js @@ -0,0 +1,85 @@ +const assert = require("assert"); +const { createModelRouter, disableModelRouter } = require("../dist/index.js"); + +function rateLimitError() { + return { + status: 429, + response: { data: { error: { message: "rate limit exceeded" } } } + }; +} + +function route(router, requestBody, attemptedModels = ["gpt-4o-mini"]) { + return router.handleFailure({ + error: rateLimitError(), + originalModel: "gpt-4o-mini", + requestBody, + provider: "openai", + retryCount: 0, + attemptedModels + }); +} + +try { + const router = createModelRouter({ + strategy: "smart", + maxRetries: 2, + smartRouting: { + confidenceThreshold: 0.3, + taskModelMap: { + code: "gpt-4.1", + math: "o1-mini" + }, + fallbackModels: ["gpt-4.1-mini"] + } + }); + + const codeDecision = route(router, { + messages: [ + { + role: "user", + content: "Please debug this TypeScript function bug and add a unit test." + } + ] + }); + assert.equal(codeDecision.retry, true); + assert.equal(codeDecision.nextModel, "gpt-4.1"); + assert.match(codeDecision.reason, /Smart routing selected code model/); + + const mathDecision = route(router, { + prompt: "Solve this probability equation and explain the statistics." + }); + assert.equal(mathDecision.retry, true); + assert.equal(mathDecision.nextModel, "o1-mini"); + + const fallbackRouter = createModelRouter({ + strategy: "smart", + maxRetries: 1, + smartRouting: { + confidenceThreshold: 0.9, + fallbackModels: ["gpt-4.1-mini"] + } + }); + const fallbackDecision = route(fallbackRouter, { + messages: [{ role: "user", content: "Hello there." }] + }); + assert.equal(fallbackDecision.retry, true); + assert.equal(fallbackDecision.nextModel, "gpt-4.1-mini"); + assert.match(fallbackDecision.reason, /fallback model/); + + const attemptedDecision = route(router, { + input: "Refactor this code and debug the failing function." + }, ["gpt-4o-mini", "gpt-4.1", "gpt-4.1-mini"]); + assert.equal(attemptedDecision.retry, false); + assert.match(attemptedDecision.reason, /could not find an eligible model/); + + assert.throws(() => { + createModelRouter({ + strategy: "smart", + smartRouting: { confidenceThreshold: 1.5 } + }); + }, /confidenceThreshold must be between 0 and 1/); + + console.log("smart-strategy tests passed"); +} finally { + disableModelRouter(); +}