Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions src/router/modelRouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@ import {
ModelRouterOptions,
FailureContext,
RoutingDecision,
RoutingStrategy
RoutingStrategy,
SmartRoutingOptions
} from "./types";
import { errorDetector } from "./errorDetector";
import { fallbackStrategy, contextStrategy, costStrategy } from "./routingStrategies";
import {
fallbackStrategy,
contextStrategy,
costStrategy,
smartStrategy
} from "./routingStrategies";
import { apiKeyManager } from "./apiKeyManager";

/**
Expand All @@ -17,12 +23,14 @@ export class ModelRouter {
private fallbackMap: Record<string, string[]>;
private maxRetries: number;
private crossProviderEnabled: boolean;
private smartRouting: SmartRoutingOptions;

constructor(options: ModelRouterOptions) {
this.strategy = options.strategy;
this.fallbackMap = options.fallbackMap || {};
this.maxRetries = options.maxRetries ?? 1;
this.crossProviderEnabled = options.enableCrossProvider ?? false;
this.smartRouting = options.smartRouting || {};

// Register API keys if provided
if (options.apiKeys) {
Expand Down Expand Up @@ -72,6 +80,57 @@ export class ModelRouter {
}
}
}

if (this.strategy === "smart") {
const threshold = this.smartRouting.confidenceThreshold;
if (
threshold !== undefined &&
(typeof threshold !== "number" || threshold < 0 || threshold > 1)
) {
throw new Error(
"TokenFirewall Router: smartRouting.confidenceThreshold must be between 0 and 1"
);
}

this.validateModelMap("smartRouting.taskModelMap", this.smartRouting.taskModelMap);
this.validateModelList("smartRouting.fallbackModels", this.smartRouting.fallbackModels);
}
}

/**
 * Validates an optional task-to-model map taken from the router options.
 *
 * `undefined` is accepted and means "not configured". Any other value has to
 * be a non-null, non-array object whose keys and values are all non-blank
 * strings.
 *
 * @param name - Option path used to prefix the error message
 * @param value - Map to validate
 * @throws Error when the value is not an object or contains a blank key or
 *   a non-string / blank model name
 */
private validateModelMap(name: string, value?: Record<string, string>): void {
  if (value === undefined) {
    return;
  }

  const isObjectMap =
    value !== null && typeof value === "object" && !Array.isArray(value);
  if (!isObjectMap) {
    throw new Error(`TokenFirewall Router: ${name} must be an object`);
  }

  const hasInvalidEntry = Object.entries(value).some(
    ([taskName, modelName]) =>
      taskName.trim() === "" ||
      typeof modelName !== "string" ||
      modelName.trim() === ""
  );

  if (hasInvalidEntry) {
    throw new Error(
      `TokenFirewall Router: ${name} entries must map non-empty task names to model names`
    );
  }
}

/**
 * Validates an optional list of model names taken from the router options.
 *
 * `undefined` is accepted and means "not configured". Any other value has to
 * be an array in which every element is a non-blank string.
 *
 * @param name - Option path used to prefix the error message
 * @param value - List to validate
 * @throws Error when the value is not an array or holds a non-string / blank entry
 */
private validateModelList(name: string, value?: string[]): void {
  if (value === undefined) {
    return;
  }

  if (!Array.isArray(value)) {
    throw new Error(`TokenFirewall Router: ${name} must be an array`);
  }

  // Index-based walk so that holes in a sparse array are inspected as well
  for (let index = 0; index < value.length; index += 1) {
    const entry: unknown = value[index];
    const isModelName = typeof entry === "string" && entry.trim() !== "";

    if (!isModelName) {
      throw new Error(
        `TokenFirewall Router: ${name} must only contain non-empty model names`
      );
    }
  }
}

/**
Expand Down Expand Up @@ -137,6 +196,9 @@ export class ModelRouter {
case "cost":
return costStrategy(context, failureType as any);

case "smart":
return smartStrategy(context, failureType as any, this.smartRouting);

default:
throw new Error(
`TokenFirewall Router: Unknown strategy "${this.strategy}"`
Expand Down
186 changes: 185 additions & 1 deletion src/router/routingStrategies.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import { FailureContext, RoutingDecision, FailureType } from "./types";
import {
FailureContext,
RoutingDecision,
FailureType,
SmartRoutingOptions
} from "./types";
import { contextRegistry } from "../introspection/contextRegistry";
import { pricingRegistry } from "../core/pricingRegistry";

Expand Down Expand Up @@ -185,6 +190,185 @@ export function costStrategy(
};
}

/**
 * Built-in model chosen for each task label.
 * smartStrategy spreads the caller's `taskModelMap` over this table, so any
 * key supplied by the caller replaces the entry listed here.
 * The keys are the same task labels used by SMART_TASK_KEYWORDS.
 */
const DEFAULT_SMART_TASK_MODELS: Record<string, string> = {
code: "gpt-4.1",
analysis: "gpt-4o",
math: "o1-mini",
summarization: "gpt-4o-mini",
chat: "gpt-4o-mini"
};

/**
 * Keywords that classifyRequestTask looks for when scoring each task label.
 *
 * - The prompt text is lower-cased before it is compared, so every entry
 *   must be written in lower case.
 * - The confidence of the winning task is derived from hits divided by the
 *   length of that task's list, so lengthening a list lowers the confidence
 *   contributed by a single hit.
 */
const SMART_TASK_KEYWORDS: Record<string, string[]> = {
code: [
"bug",
"code",
"debug",
"function",
"refactor",
"stack trace",
"typescript",
"unit test"
],
analysis: [
"analyze",
"compare",
"evaluate",
"explain",
"insight",
"recommend",
"tradeoff"
],
math: [
"calculate",
"equation",
"math",
"probability",
"proof",
"solve",
"statistics"
],
summarization: [
"brief",
"condense",
"notes",
"recap",
"summarize",
"summary",
"tl;dr"
],
chat: [
"chat",
"conversation",
"friendly",
"reply",
"rewrite",
"tone"
]
};

/**
 * Smart routing strategy
 * Classifies the request body and selects a task-specific model.
 *
 * Selection order:
 * 1. When the classifier confidence reaches the threshold, the model mapped
 *    to the detected task is used, provided it is eligible.
 * 2. Otherwise the first eligible entry of `options.fallbackModels` is used.
 * 3. If nothing is eligible, no retry is requested.
 *
 * A model is eligible when it is neither the model that originally failed
 * nor one that has already been attempted. The same rule is applied to the
 * task-specific model and to the fallback list, so the router never answers
 * a failure by re-selecting the model that just failed.
 *
 * @param context - Failure context; `requestBody`, `originalModel` and
 *   `attemptedModels` are read
 * @param failureType - Detected failure type, used only in the reason text
 * @param options - Smart routing configuration; the threshold defaults to
 *   0.35 and `taskModelMap` entries override the built-in task models
 * @returns Routing decision with `retry`, an optional `nextModel` and a reason
 */
export function smartStrategy(
  context: FailureContext,
  failureType: FailureType,
  options: SmartRoutingOptions = {}
): RoutingDecision {
  const confidenceThreshold = options.confidenceThreshold ?? 0.35;
  const taskModelMap: Record<string, string> = {
    ...DEFAULT_SMART_TASK_MODELS,
    ...(options.taskModelMap ?? {})
  };

  // Single eligibility rule shared by the task-specific and fallback paths
  const isEligible = (model: string): boolean =>
    model !== context.originalModel && !context.attemptedModels.includes(model);

  const classification = classifyRequestTask(context.requestBody);
  const confidenceLabel = classification.confidence.toFixed(2);

  if (classification.confidence >= confidenceThreshold) {
    const nextModel = taskModelMap[classification.task];

    if (nextModel && isEligible(nextModel)) {
      return {
        retry: true,
        nextModel,
        reason:
          `Smart routing selected ${classification.task} model ` +
          `after ${failureType} (confidence ${confidenceLabel})`
      };
    }
  }

  const fallbackModel = (options.fallbackModels ?? []).find(isEligible);

  if (fallbackModel) {
    return {
      retry: true,
      nextModel: fallbackModel,
      reason:
        `Smart routing used fallback model after ${failureType}; ` +
        `task confidence was ${confidenceLabel}`
    };
  }

  return {
    retry: false,
    reason:
      `Smart routing could not find an eligible model ` +
      `(task=${classification.task}, confidence=${confidenceLabel})`
  };
}

/** Compiled keyword matchers, cached so each pattern is built only once. */
const smartKeywordPatterns = new Map<string, RegExp>();

/**
 * Returns a matcher that finds `keyword` only where it starts a word, i.e.
 * where it is not preceded by a letter, digit or underscore.
 *
 * Inflected forms such as "bugs" or "summarized" still match, while
 * embedded occurrences such as "decode" (code) or "aftermath" (math) do not.
 */
function smartKeywordPattern(keyword: string): RegExp {
  let pattern = smartKeywordPatterns.get(keyword);
  if (!pattern) {
    // Escape regex syntax characters so the keyword is matched literally
    const escaped = keyword.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
    pattern = new RegExp(`(?<![\\p{L}\\p{N}_])${escaped}`, "u");
    smartKeywordPatterns.set(keyword, pattern);
  }
  return pattern;
}

/**
 * Guesses the task type of a request from the keywords found in its prompt.
 *
 * Each task in SMART_TASK_KEYWORDS is scored by the number of its keywords
 * that occur in the lower-cased prompt text at the start of a word. The task
 * with the highest score wins; on a tie the task declared first wins.
 *
 * @param requestBody - Request payload; its text is pulled out with extractPromptText
 * @returns The detected task and a confidence value:
 *   - `chat` / 0 when the request contains no text,
 *   - `chat` / 0.2 when no keyword matches,
 *   - otherwise 0.25 plus the fraction of the winning task's keywords that
 *     matched, capped at 0.95.
 */
function classifyRequestTask(requestBody: unknown): { task: string; confidence: number } {
  const text = extractPromptText(requestBody).toLowerCase();

  if (!text) {
    return { task: "chat", confidence: 0 };
  }

  let bestTask = "chat";
  let bestScore = 0;
  let bestKeywordCount = 0;

  for (const [task, keywords] of Object.entries(SMART_TASK_KEYWORDS)) {
    const score = keywords.filter(keyword => smartKeywordPattern(keyword).test(text)).length;

    // Strict comparison keeps the earliest declared task when scores tie
    if (score > bestScore) {
      bestTask = task;
      bestScore = score;
      bestKeywordCount = keywords.length;
    }
  }

  if (bestScore === 0) {
    return { task: "chat", confidence: 0.2 };
  }

  return {
    task: bestTask,
    confidence: Math.min(0.95, 0.25 + bestScore / Math.max(bestKeywordCount, 1))
  };
}

/**
 * Flattens the textual parts of a request payload into one space-separated string.
 *
 * - Strings are returned unchanged.
 * - Arrays contribute the text of every element, in order.
 * - Objects contribute their `prompt`, `input`, `query`, `text` and `content`
 *   fields (in that order), followed by the entries of a `messages` array:
 *   string entries as-is, object entries through their `content` field.
 * - Every other value, and every empty fragment, contributes nothing.
 *
 * @param value - Arbitrary request payload or a nested part of it
 * @returns The collected text, or an empty string when none was found
 */
function extractPromptText(value: unknown): string {
  if (typeof value === "string") {
    return value;
  }

  // Only non-null objects (arrays included) can carry nested text
  if (value === null || typeof value !== "object") {
    return "";
  }

  const fragments: string[] = [];
  const collect = (candidate: unknown): void => {
    const fragment = extractPromptText(candidate);
    if (fragment) {
      fragments.push(fragment);
    }
  };

  if (Array.isArray(value)) {
    for (const item of value) {
      collect(item);
    }
    return fragments.join(" ");
  }

  const body = value as Record<string, unknown>;

  collect(body.prompt);
  collect(body.input);
  collect(body.query);
  collect(body.text);
  collect(body.content);

  const messages = body.messages;
  if (Array.isArray(messages)) {
    for (const entry of messages) {
      if (typeof entry === "string") {
        collect(entry);
      } else if (entry && typeof entry === "object") {
        collect((entry as Record<string, unknown>).content);
      }
    }
  }

  return fragments.join(" ");
}

/**
* Helper to get known models for a provider
* Uses context registry for dynamic model discovery, falls back to static list
Expand Down
16 changes: 15 additions & 1 deletion src/router/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
/**
* Routing strategy types
*/
export type RoutingStrategy = "fallback" | "context" | "cost";
export type RoutingStrategy = "fallback" | "context" | "cost" | "smart";

/**
* Failure types detected by error detector
Expand All @@ -29,6 +29,18 @@ export interface ApiKeyConfig {
[key: string]: string | undefined;
}

/**
 * Configuration for the smart routing strategy.
 * Every field is optional; an empty object is a valid configuration.
 */
export interface SmartRoutingOptions {
/**
 * Minimum classifier confidence required before choosing a task-specific model.
 * smartStrategy uses 0.35 when this is omitted. When the router strategy is
 * "smart", ModelRouter rejects values that are not numbers between 0 and 1.
 */
confidenceThreshold?: number;
/**
 * Model to use per detected task type.
 * Entries are merged over the built-in task models, so only the tasks that
 * should differ from the defaults need to be listed.
 */
taskModelMap?: Record<string, string>;
/**
 * Fallback models to try when confidence is low or the task-specific model was attempted.
 * The list is searched in order and the first model that is neither the
 * original model nor already attempted is used.
 */
fallbackModels?: string[];
}

/**
* Configuration options for model router
*/
Expand All @@ -43,6 +55,8 @@ export interface ModelRouterOptions {
apiKeys?: ApiKeyConfig;
/** Enable cross-provider fallback (default: false) */
enableCrossProvider?: boolean;
/** Smart routing configuration */
smartRouting?: SmartRoutingOptions;
}

/**
Expand Down
Loading