Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,12 @@ Creates and configures an intelligent model router.

```typescript
interface ModelRouterOptions {
strategy: "fallback" | "context" | "cost"; // Routing strategy
strategy: "fallback" | "context" | "cost" | "smart"; // Routing strategy
fallbackMap?: Record<string, string[]>; // Fallback model map
taskClassification?: Record<string, TaskClassificationRule>; // Custom task rules for smart routing
modelOverrides?: Record<string, string>; // Model overrides handed to the smart-routing task classifier
confidenceThreshold?: number; // Smart routing threshold (default: 0.7)
defaultModel?: string; // Smart routing fallback model
maxRetries?: number; // Max retry attempts (default: 1)
}
```
Expand Down Expand Up @@ -324,6 +328,11 @@ patchGlobalFetch();
- Selects cheaper model from same provider
- Best for: Cost optimization, rate limit handling

**4. Smart Strategy** - Selects a model based on task classification
- Detects task intent from prompt/message content
- Supports custom task rules, model overrides, confidence threshold, and default model
- Best for: Task-aware model quality and cost optimization

### Error Detection

The router automatically detects and classifies failures:
Expand Down
2 changes: 2 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
"types": "dist/index.d.ts",
"scripts": {
"build": "tsc",
"test": "npm run build && npm run test:smart",
"test:smart": "node tests/smart-router.test.js",
"prepublishOnly": "npm run build"
},
"keywords": [
Expand Down
4 changes: 3 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,9 @@ export type {
FailureContext,
RoutingDecision,
RouterEvent,
ApiKeyConfig
ApiKeyConfig,
TaskClassification,
TaskClassificationRule
} from "./router/types";

/**
Expand Down
103 changes: 102 additions & 1 deletion src/router/modelRouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ import {
ModelRouterOptions,
FailureContext,
RoutingDecision,
RoutingStrategy
RoutingStrategy,
TaskClassificationRule
} from "./types";
import { errorDetector } from "./errorDetector";
import { fallbackStrategy, contextStrategy, costStrategy } from "./routingStrategies";
import { apiKeyManager } from "./apiKeyManager";
import { detectProvider } from "./providerDetector";
import { TaskClassifier } from "./taskClassifier";

/**
* Intelligent Model Router
Expand All @@ -17,12 +20,26 @@ export class ModelRouter {
private fallbackMap: Record<string, string[]>;
private maxRetries: number;
private crossProviderEnabled: boolean;
private taskClassifier: TaskClassifier | null;
private taskClassification?: Record<string, TaskClassificationRule>;
private confidenceThreshold: number;
private defaultModel?: string;

constructor(options: ModelRouterOptions) {
this.strategy = options.strategy;
this.fallbackMap = options.fallbackMap || {};
this.maxRetries = options.maxRetries ?? 1;
this.crossProviderEnabled = options.enableCrossProvider ?? false;
this.taskClassification = options.taskClassification;
this.confidenceThreshold = options.confidenceThreshold ?? 0.7;
this.defaultModel = options.defaultModel;
this.taskClassifier =
options.strategy === "smart"
? new TaskClassifier(
options.taskClassification,
options.modelOverrides
)
: null;

// Register API keys if provided
if (options.apiKeys) {
Expand All @@ -46,6 +63,25 @@ export class ModelRouter {
);
}

if (this.confidenceThreshold < 0 || this.confidenceThreshold > 1) {
throw new Error(
"TokenFirewall Router: confidenceThreshold must be between 0 and 1"
);
}

if (
this.defaultModel !== undefined &&
(typeof this.defaultModel !== "string" || this.defaultModel.trim() === "")
) {
throw new Error(
"TokenFirewall Router: defaultModel must be a non-empty string when provided"
);
}

if (this.strategy === "smart") {
this.validateTaskClassification();
}

if (this.strategy === "fallback") {
if (Object.keys(this.fallbackMap).length === 0) {
throw new Error(
Expand Down Expand Up @@ -137,6 +173,9 @@ export class ModelRouter {
case "cost":
return costStrategy(context, failureType as any);

case "smart":
return this.smartStrategy(context, failureType);

default:
throw new Error(
`TokenFirewall Router: Unknown strategy "${this.strategy}"`
Expand Down Expand Up @@ -164,4 +203,66 @@ export class ModelRouter {
public isCrossProviderEnabled(): boolean {
// Exposes the private flag that the constructor initializes from
// options.enableCrossProvider (false when the option is omitted).
return this.crossProviderEnabled;
}

/**
 * Check that every custom smart-routing rule is a non-null object.
 *
 * Does nothing when no custom task classification was supplied.
 *
 * @throws Error naming the offending task when its rule is null or is not
 *   of type "object".
 */
private validateTaskClassification(): void {
  const customRules = this.taskClassification;

  if (customRules) {
    Object.entries(customRules).forEach(([taskName, candidate]) => {
      // A usable rule must be a real object; null also reports typeof
      // "object", so it is excluded explicitly.
      const isObjectRule = candidate !== null && typeof candidate === "object";

      if (!isObjectRule) {
        throw new Error(
          `TokenFirewall Router: smart task "${taskName}" must be an object`
        );
      }
    });
  }
}

/**
 * Task-aware model selection for smart routing.
 *
 * Classifies the request body with the task classifier. The classifier's
 * model is used only when its confidence is at least `confidenceThreshold`;
 * otherwise `defaultModel` (if configured) is used instead.
 *
 * @param context - Failure context; `requestBody` is classified and
 *   `provider` is compared against the provider detected for the chosen model.
 * @param failureType - Failure label, used only in the returned reason text.
 * @returns `retry: true` with `nextModel` when a model was chosen and is
 *   allowed; `retry: false` with an explanatory reason when no model could be
 *   chosen or the chosen model belongs to a different provider while
 *   cross-provider routing is disabled.
 */
private smartStrategy(
  context: FailureContext,
  failureType: string
): RoutingDecision {
  const classification = this.taskClassifier?.classify(context.requestBody);

  // Keep only a classification that clears the threshold. Tracking this
  // separately lets the final reason describe the source that actually
  // supplied the model, instead of quoting a rejected low-confidence
  // classification when the default model was used.
  const accepted =
    classification && classification.confidence >= this.confidenceThreshold
      ? classification
      : undefined;
  const selectedModel = accepted ? accepted.selectedModel : this.defaultModel;

  if (!selectedModel) {
    // Surface the (insufficient) confidence when a classification exists.
    const confidence = classification
      ? ` (confidence ${classification.confidence.toFixed(2)})`
      : "";
    return {
      retry: false,
      reason: `Smart strategy could not classify request above threshold${confidence}`
    };
  }

  // The provider check is skipped when no provider can be detected for the
  // selected model.
  const selectedProvider = detectProvider(selectedModel);
  if (
    selectedProvider &&
    selectedProvider !== context.provider &&
    !this.crossProviderEnabled
  ) {
    return {
      retry: false,
      reason:
        `Smart strategy selected ${selectedModel}, but cross-provider routing is disabled`
    };
  }

  return {
    retry: true,
    nextModel: selectedModel,
    reason: accepted
      ? `${accepted.reason} after ${failureType}`
      : `Using default smart-routing model after ${failureType}`
  };
}
}
Loading