/**
 * Smart-routing task configuration as consumed by this module.
 *
 * NOTE(review): declared locally so the module is self-contained. It lists
 * only the fields that BUILT_IN_TASKS and TaskRegistry actually read; if the
 * project already exports an equivalent type, import that one instead —
 * TODO confirm.
 */
interface TaskConfiguration {
  /** Registry key; stored trimmed and lower-cased. */
  taskType: string;
  /** Identifier of the model this task is routed to. */
  model: string;
  /** Human-readable justification for the model choice. */
  reason: string;
  /** Optional non-empty keyword strings associated with the task. */
  keywords?: string[];
  /** Optional regular expressions associated with the task. */
  patterns?: RegExp[];
  /** Optional non-negative rank; `getAllTasks` lists higher values first. */
  priority?: number;
  /** Optional non-negative threshold; its unit is not defined in this module. */
  contextThreshold?: number;
}

/**
 * Payload accepted by `TaskRegistry.registerTask`. The task type is passed as
 * a separate argument, so it is optional here and ignored when present.
 */
export type TaskRegistration = Omit<TaskConfiguration, "taskType"> & {
  taskType?: string;
};

/**
 * Copy a task configuration so that neither its arrays nor its RegExp objects
 * are shared with the source.
 *
 * RegExp instances are re-created because `g`/`y` patterns carry mutable
 * `lastIndex` state; sharing them would let one caller's `test()`/`exec()`
 * change the result another caller gets from the "same" configuration.
 */
function cloneTaskConfiguration(task: TaskConfiguration): TaskConfiguration {
  return {
    ...task,
    keywords: task.keywords ? [...task.keywords] : undefined,
    patterns: task.patterns
      ? task.patterns.map(pattern => new RegExp(pattern.source, pattern.flags))
      : undefined,
  };
}

/**
 * Task configurations every default `TaskRegistry` starts with.
 */
const BUILT_IN_TASKS: readonly TaskConfiguration[] = [
  {
    taskType: "code_generation",
    model: "claude-3-5-sonnet-20241022",
    reason: "Claude excels at producing maintainable code",
    keywords: ["write code", "create function", "implement", "build", "develop", "program"],
    patterns: [/write.*code/i, /create.*function/i, /implement.*class/i],
    priority: 100,
  },
  {
    taskType: "code_review",
    model: "claude-3-5-sonnet-20241022",
    reason: "Claude is strong at code review, debugging, and refactoring",
    keywords: ["review code", "find bugs", "optimize", "refactor", "improve code", "debug"],
    patterns: [/review.*code/i, /find.*bug/i, /refactor/i, /optimize/i],
    priority: 95,
  },
  {
    taskType: "math_reasoning",
    model: "o1-mini",
    reason: "o1-mini is optimized for mathematical reasoning",
    keywords: ["calculate", "solve", "compute", "equation", "formula", "math"],
    patterns: [/solve.*equation/i, /calculate/i, /mathematical/i],
    priority: 90,
  },
  {
    taskType: "complex_reasoning",
    model: "o1",
    reason: "o1 is suited to deeper logic and step-by-step reasoning",
    keywords: ["analyze", "reason", "logic", "deduce", "infer", "prove", "derive"],
    patterns: [/step.*by.*step/i, /reasoning/i, /logical.*analysis/i],
    priority: 85,
  },
  {
    taskType: "document_analysis",
    model: "gemini-2.5-pro",
    reason: "Gemini has a very large context window for long documents",
    keywords: ["summarize document", "analyze document", "extract from", "review document"],
    patterns: [/summarize.*document/i, /analyze.*pdf/i, /extract.*information/i],
    priority: 80,
    contextThreshold: 50000,
  },
  {
    taskType: "creative_writing",
    model: "gpt-4o",
    reason: "GPT-4o is effective for creative and engaging prose",
    keywords: ["write story", "create content", "blog post", "article", "creative"],
    patterns: [/write.*story/i, /creative.*writing/i, /blog.*post/i],
    priority: 75,
  },
  {
    taskType: "technical_documentation",
    model: "claude-3-5-sonnet-20241022",
    reason: "Claude is strong at technical writing and documentation",
    keywords: ["document", "documentation", "api docs", "technical writing"],
    patterns: [/write.*documentation/i, /create.*docs/i],
    priority: 72,
  },
  {
    taskType: "translation",
    model: "gpt-4o-mini",
    reason: "Translation is usually handled well by a cost-effective model",
    keywords: ["translate", "translation", "convert to"],
    patterns: [/translate.*to/i, /translation/i],
    priority: 70,
  },
  {
    taskType: "simple_chat",
    model: "gpt-4o-mini",
    reason: "GPT-4o-mini is fast and inexpensive for simple conversation",
    keywords: ["hello", "hi", "how are you", "thanks", "thank you", "help"],
    patterns: [/^(hi|hello|hey)/i, /how.*are.*you/i, /thank/i],
    priority: 65,
  },
  {
    taskType: "data_extraction",
    model: "gpt-4o-mini",
    reason: "Structured extraction works well on a cost-effective model",
    keywords: ["extract", "parse", "get data from", "scrape", "pull data"],
    patterns: [/extract.*from/i, /parse.*json/i, /get.*data/i],
    priority: 60,
  },
  {
    taskType: "factual_qa",
    model: "gpt-4o-mini",
    reason: "GPT-4o-mini is efficient for straightforward factual answers",
    keywords: ["what is", "who is", "when did", "where is", "how many"],
    patterns: [/^(what|who|when|where|how|why)\b/i],
    priority: 55,
  },
  {
    taskType: "chinese_language",
    model: "moonshot-v1-32k",
    reason: "Kimi is optimized for Chinese language understanding",
    keywords: ["chinese"],
    patterns: [/[\u4e00-\u9fff]/],
    priority: 98,
  },
];

/**
 * Stores built-in and custom smart-routing task configurations.
 *
 * Task types are case-insensitive and whitespace-trimmed. Every configuration
 * is copied on the way in and on the way out, so callers can never mutate the
 * registry's internal state through a reference they hold.
 */
export class TaskRegistry {
  private readonly tasks: Map<string, TaskConfiguration> = new Map();

  /**
   * @param initialTasks Configurations to register up front; defaults to the
   *   built-in task list. Pass `[]` for an empty registry.
   * @throws Error if any initial configuration fails validation.
   */
  constructor(initialTasks: readonly TaskConfiguration[] = BUILT_IN_TASKS) {
    for (const task of initialTasks) {
      this.registerTask(task.taskType, task);
    }
  }

  /**
   * Register a task configuration, replacing any existing entry with the same
   * (normalized) task type.
   *
   * @param taskType Registry key; trimmed and lower-cased before use.
   * @param config Task settings. A `taskType` inside `config` is overridden by
   *   the `taskType` argument.
   * @throws Error if the resulting configuration fails `validateTask`.
   */
  public registerTask(taskType: string, config: TaskRegistration): void {
    const normalizedTaskType = this.normalizeTaskType(taskType);
    const taskConfig: TaskConfiguration = {
      ...config,
      taskType: normalizedTaskType,
    };

    // Validate before cloning: cloning assumes `patterns` holds RegExp objects.
    this.validateTask(taskConfig);
    this.tasks.set(normalizedTaskType, cloneTaskConfiguration(taskConfig));
  }

  /**
   * Look up a task configuration by type.
   *
   * @returns An independent copy of the configuration, or `undefined` when the
   *   type is not registered (including when `taskType` is not a string).
   */
  public getTask(taskType: string): TaskConfiguration | undefined {
    const task = this.tasks.get(this.normalizeTaskType(taskType));
    return task ? cloneTaskConfiguration(task) : undefined;
  }

  /**
   * Get independent copies of all tasks sorted by priority, highest first.
   * Tasks without a priority sort as priority 0.
   */
  public getAllTasks(): TaskConfiguration[] {
    return Array.from(this.tasks.values(), task => cloneTaskConfiguration(task)).sort(
      (a, b) => (b.priority ?? 0) - (a.priority ?? 0)
    );
  }

  /**
   * Check whether a task type is registered (case-insensitive, trimmed).
   */
  public hasTask(taskType: string): boolean {
    return this.tasks.has(this.normalizeTaskType(taskType));
  }

  /**
   * Validate a task configuration.
   *
   * @throws Error describing the first invalid field: `taskType`, `model` and
   *   `reason` must be non-empty strings; `keywords` an array of non-empty
   *   strings; `patterns` an array of RegExp instances; `priority` and
   *   `contextThreshold` finite, non-negative numbers.
   */
  public validateTask(task: TaskConfiguration): void {
    // Untyped (JavaScript) callers can pass anything; fail with a clear message.
    if (task === null || typeof task !== "object") {
      throw new Error("TokenFirewall TaskRegistry: task configuration must be an object");
    }

    if (!task.taskType || typeof task.taskType !== "string" || task.taskType.trim() === "") {
      throw new Error("TokenFirewall TaskRegistry: taskType must be a non-empty string");
    }

    if (!task.model || typeof task.model !== "string" || task.model.trim() === "") {
      throw new Error(`TokenFirewall TaskRegistry: model must be a non-empty string for task "${task.taskType}"`);
    }

    if (!task.reason || typeof task.reason !== "string" || task.reason.trim() === "") {
      throw new Error(`TokenFirewall TaskRegistry: reason must be a non-empty string for task "${task.taskType}"`);
    }

    if (task.keywords !== undefined && !this.isStringArray(task.keywords)) {
      throw new Error(`TokenFirewall TaskRegistry: keywords must be an array of strings for task "${task.taskType}"`);
    }

    // Array.isArray guard: a non-array value must produce this error, not a
    // TypeError from calling `.every` on it.
    if (task.patterns !== undefined && !this.isRegExpArray(task.patterns)) {
      throw new Error(`TokenFirewall TaskRegistry: patterns must be an array of RegExp instances for task "${task.taskType}"`);
    }

    if (task.priority !== undefined && (!Number.isFinite(task.priority) || task.priority < 0)) {
      throw new Error(`TokenFirewall TaskRegistry: priority must be a non-negative number for task "${task.taskType}"`);
    }

    if (task.contextThreshold !== undefined && (!Number.isFinite(task.contextThreshold) || task.contextThreshold < 0)) {
      throw new Error(
        `TokenFirewall TaskRegistry: contextThreshold must be a non-negative number for task "${task.taskType}"`
      );
    }
  }

  /**
   * Canonical registry key for a task type. Non-string input maps to the empty
   * string, which can never be registered, so lookups miss and registration
   * fails validation instead of throwing a TypeError here.
   */
  private normalizeTaskType(taskType: string): string {
    return typeof taskType === "string" ? taskType.trim().toLowerCase() : "";
  }

  /** True when `value` is an array whose items are all non-blank strings. */
  private isStringArray(value: unknown): value is string[] {
    return Array.isArray(value) && value.every(item => typeof item === "string" && item.trim() !== "");
  }

  /** True when `value` is an array whose items are all RegExp instances. */
  private isRegExpArray(value: unknown): value is RegExp[] {
    return Array.isArray(value) && value.every(item => item instanceof RegExp);
  }
}

/** Shared registry instance pre-populated with the built-in tasks. */
export const taskRegistry = new TaskRegistry();

/**
 * Independent copies of the built-in task configurations. Mutating these does
 * not affect the defaults used by `new TaskRegistry()`.
 */
export const builtInTasks: TaskConfiguration[] = BUILT_IN_TASKS.map(task => cloneTaskConfiguration(task));