Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { contextRegistry } from "./introspection/contextRegistry";
import { ModelRouter } from "./router/modelRouter";
import { ModelRouterOptions, ApiKeyConfig } from "./router/types";
import { apiKeyManager } from "./router/apiKeyManager";
import { taskRegistry, TaskRegistry } from "./router/taskRegistry";

let globalBudgetManager: BudgetManager | null = null;
let globalModelRouter: ModelRouter | null = null;
Expand Down Expand Up @@ -227,6 +228,7 @@ export async function listModels(options: Omit<ListModelsOptions, 'budgetManager

// Keep the original export for backward compatibility
export { listAvailableModels };
export { taskRegistry, TaskRegistry };

// Export types for TypeScript users
export type {
Expand All @@ -252,6 +254,11 @@ export type {
ApiKeyConfig
} from "./router/types";

export type {
TaskConfiguration,
TaskRegistration
} from "./router/taskRegistry";

/**
* Model configuration for bulk registration
*/
Expand Down
229 changes: 229 additions & 0 deletions src/router/taskRegistry.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
/**
* Task configuration registry for Smart Model Selection.
*/

export interface TaskConfiguration {
/** Stable task identifier, for example "code_generation". */
taskType: string;
/** Preferred model for this task. */
model: string;
/** Human-readable reason for selecting the model. */
reason: string;
/** Keyword hints used by lightweight classifiers. */
keywords?: string[];
/** Regex patterns used by pattern detectors. */
patterns?: RegExp[];
/** Higher-priority tasks are evaluated first. */
priority?: number;
/** Optional token threshold for context-sensitive tasks. */
contextThreshold?: number;
}

export type TaskRegistration = Omit<TaskConfiguration, "taskType"> & {
taskType?: string;
};

const BUILT_IN_TASKS: TaskConfiguration[] = [
{
taskType: "code_generation",
model: "claude-3-5-sonnet-20241022",
reason: "Claude excels at producing maintainable code",
keywords: ["write code", "create function", "implement", "build", "develop", "program"],
patterns: [/write.*code/i, /create.*function/i, /implement.*class/i],
priority: 100,
},
{
taskType: "code_review",
model: "claude-3-5-sonnet-20241022",
reason: "Claude is strong at code review, debugging, and refactoring",
keywords: ["review code", "find bugs", "optimize", "refactor", "improve code", "debug"],
patterns: [/review.*code/i, /find.*bug/i, /refactor/i, /optimize/i],
priority: 95,
},
{
taskType: "math_reasoning",
model: "o1-mini",
reason: "o1-mini is optimized for mathematical reasoning",
keywords: ["calculate", "solve", "compute", "equation", "formula", "math"],
patterns: [/solve.*equation/i, /calculate/i, /mathematical/i],
priority: 90,
},
{
taskType: "complex_reasoning",
model: "o1",
reason: "o1 is suited to deeper logic and step-by-step reasoning",
keywords: ["analyze", "reason", "logic", "deduce", "infer", "prove", "derive"],
patterns: [/step.*by.*step/i, /reasoning/i, /logical.*analysis/i],
priority: 85,
},
{
taskType: "document_analysis",
model: "gemini-2.5-pro",
reason: "Gemini has a very large context window for long documents",
keywords: ["summarize document", "analyze document", "extract from", "review document"],
patterns: [/summarize.*document/i, /analyze.*pdf/i, /extract.*information/i],
priority: 80,
contextThreshold: 50000,
},
{
taskType: "creative_writing",
model: "gpt-4o",
reason: "GPT-4o is effective for creative and engaging prose",
keywords: ["write story", "create content", "blog post", "article", "creative"],
patterns: [/write.*story/i, /creative.*writing/i, /blog.*post/i],
priority: 75,
},
{
taskType: "technical_documentation",
model: "claude-3-5-sonnet-20241022",
reason: "Claude is strong at technical writing and documentation",
keywords: ["document", "documentation", "api docs", "technical writing"],
patterns: [/write.*documentation/i, /create.*docs/i],
priority: 72,
},
{
taskType: "translation",
model: "gpt-4o-mini",
reason: "Translation is usually handled well by a cost-effective model",
keywords: ["translate", "translation", "convert to"],
patterns: [/translate.*to/i, /translation/i],
priority: 70,
},
{
taskType: "simple_chat",
model: "gpt-4o-mini",
reason: "GPT-4o-mini is fast and inexpensive for simple conversation",
keywords: ["hello", "hi", "how are you", "thanks", "thank you", "help"],
patterns: [/^(hi|hello|hey)/i, /how.*are.*you/i, /thank/i],
priority: 65,
},
{
taskType: "data_extraction",
model: "gpt-4o-mini",
reason: "Structured extraction works well on a cost-effective model",
keywords: ["extract", "parse", "get data from", "scrape", "pull data"],
patterns: [/extract.*from/i, /parse.*json/i, /get.*data/i],
priority: 60,
},
{
taskType: "factual_qa",
model: "gpt-4o-mini",
reason: "GPT-4o-mini is efficient for straightforward factual answers",
keywords: ["what is", "who is", "when did", "where is", "how many"],
patterns: [/^(what|who|when|where|how|why)\b/i],
priority: 55,
},
{
taskType: "chinese_language",
model: "moonshot-v1-32k",
reason: "Kimi is optimized for Chinese language understanding",
keywords: ["chinese"],
patterns: [/[\u4e00-\u9fff]/],
priority: 98,
},
];

/**
* Stores built-in and custom smart-routing task configurations.
*/
export class TaskRegistry {
private tasks: Map<string, TaskConfiguration> = new Map();

constructor(initialTasks: TaskConfiguration[] = BUILT_IN_TASKS) {
for (const task of initialTasks) {
this.registerTask(task.taskType, task);
}
}

/**
* Register or replace a task configuration.
*/
public registerTask(taskType: string, config: TaskRegistration): void {
const normalizedTaskType = this.normalizeTaskType(taskType);
const taskConfig: TaskConfiguration = {
...config,
taskType: normalizedTaskType,
};

this.validateTask(taskConfig);
this.tasks.set(normalizedTaskType, this.cloneTask(taskConfig));
}

/**
* Get a task configuration by type.
*/
public getTask(taskType: string): TaskConfiguration | undefined {
const task = this.tasks.get(this.normalizeTaskType(taskType));
return task ? this.cloneTask(task) : undefined;
}

/**
* Get all tasks sorted by priority, highest first.
*/
public getAllTasks(): TaskConfiguration[] {
return Array.from(this.tasks.values())
.map(task => this.cloneTask(task))
.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0));
}

/**
* Check whether a task type is registered.
*/
public hasTask(taskType: string): boolean {
return this.tasks.has(this.normalizeTaskType(taskType));
}

/**
* Validate a task configuration.
*/
public validateTask(task: TaskConfiguration): void {
if (!task.taskType || typeof task.taskType !== "string" || task.taskType.trim() === "") {
throw new Error("TokenFirewall TaskRegistry: taskType must be a non-empty string");
}

if (!task.model || typeof task.model !== "string" || task.model.trim() === "") {
throw new Error(`TokenFirewall TaskRegistry: model must be a non-empty string for task "${task.taskType}"`);
}

if (!task.reason || typeof task.reason !== "string" || task.reason.trim() === "") {
throw new Error(`TokenFirewall TaskRegistry: reason must be a non-empty string for task "${task.taskType}"`);
}

if (task.keywords !== undefined && !this.isStringArray(task.keywords)) {
throw new Error(`TokenFirewall TaskRegistry: keywords must be an array of strings for task "${task.taskType}"`);
}

if (task.patterns !== undefined && !task.patterns.every(pattern => pattern instanceof RegExp)) {
throw new Error(`TokenFirewall TaskRegistry: patterns must be an array of RegExp instances for task "${task.taskType}"`);
}

if (task.priority !== undefined && (!Number.isFinite(task.priority) || task.priority < 0)) {
throw new Error(`TokenFirewall TaskRegistry: priority must be a non-negative number for task "${task.taskType}"`);
}

if (task.contextThreshold !== undefined && (!Number.isFinite(task.contextThreshold) || task.contextThreshold < 0)) {
throw new Error(
`TokenFirewall TaskRegistry: contextThreshold must be a non-negative number for task "${task.taskType}"`
);
}
}

private normalizeTaskType(taskType: string): string {
return taskType.trim().toLowerCase();
}

private isStringArray(value: unknown): value is string[] {
return Array.isArray(value) && value.every(item => typeof item === "string" && item.trim() !== "");
}

private cloneTask(task: TaskConfiguration): TaskConfiguration {
return {
...task,
keywords: task.keywords ? [...task.keywords] : undefined,
patterns: task.patterns ? [...task.patterns] : undefined,
};
}
}

export const taskRegistry = new TaskRegistry();
export const builtInTasks = BUILT_IN_TASKS.map(task => ({ ...task }));
112 changes: 112 additions & 0 deletions tests/task-registry.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/**
* Task registry tests.
*
* Run: npm run build && node tests/task-registry.test.js
*/

const assert = require("assert");
const { TaskRegistry, taskRegistry } = require("../dist/index.js");

function test(name, fn) {
try {
fn();
console.log(`✓ ${name}`);
} catch (error) {
console.error(`✗ ${name}`);
throw error;
}
}

test("registers all built-in smart routing tasks", () => {
const tasks = taskRegistry.getAllTasks();
const taskTypes = tasks.map(task => task.taskType);

assert.ok(tasks.length >= 12);
assert.ok(taskTypes.includes("code_generation"));
assert.ok(taskTypes.includes("document_analysis"));
assert.ok(taskTypes.includes("chinese_language"));
assert.ok(taskTypes.includes("technical_documentation"));
});

test("returns tasks ordered by priority", () => {
const registry = new TaskRegistry([
{
taskType: "low_priority",
model: "gpt-4o-mini",
reason: "Low priority",
priority: 1,
},
{
taskType: "high_priority",
model: "gpt-4o",
reason: "High priority",
priority: 99,
},
]);

assert.deepStrictEqual(
registry.getAllTasks().map(task => task.taskType),
["high_priority", "low_priority"]
);
});

test("supports custom task registration", () => {
const registry = new TaskRegistry([]);

registry.registerTask("legal_analysis", {
model: "gpt-4o",
reason: "Complex legal reasoning",
keywords: ["legal", "contract", "clause"],
patterns: [/legal.*analysis/i],
priority: 8,
});

const task = registry.getTask("legal_analysis");
assert.ok(registry.hasTask("legal_analysis"));
assert.strictEqual(task.model, "gpt-4o");
assert.strictEqual(task.patterns[0].test("legal contract analysis"), true);
});

test("defensively copies returned task arrays", () => {
const registry = new TaskRegistry([]);

registry.registerTask("simple", {
model: "gpt-4o-mini",
reason: "Simple test task",
keywords: ["hello"],
patterns: [/hello/i],
});

const task = registry.getTask("simple");
task.keywords.push("mutated");
task.patterns.push(/mutated/i);

const freshTask = registry.getTask("simple");
assert.deepStrictEqual(freshTask.keywords, ["hello"]);
assert.strictEqual(freshTask.patterns.length, 1);
});

test("validates required task fields", () => {
const registry = new TaskRegistry([]);

assert.throws(
() => registry.registerTask("bad", { model: "", reason: "missing model" }),
/model must be a non-empty string/
);

assert.throws(
() => registry.registerTask("bad", { model: "gpt-4o", reason: "" }),
/reason must be a non-empty string/
);

assert.throws(
() => registry.registerTask("bad", {
model: "gpt-4o",
reason: "invalid patterns",
patterns: ["not-regex"],
}),
/patterns must be an array of RegExp/
);
});

console.log("Task registry tests passed.");