diff --git a/src/index.ts b/src/index.ts index c769e99..812703f 100644 --- a/src/index.ts +++ b/src/index.ts @@ -9,6 +9,7 @@ import { contextRegistry } from "./introspection/contextRegistry"; import { ModelRouter } from "./router/modelRouter"; import { ModelRouterOptions, ApiKeyConfig } from "./router/types"; import { apiKeyManager } from "./router/apiKeyManager"; +import { detectLanguage } from "./router/languageDetector"; let globalBudgetManager: BudgetManager | null = null; let globalModelRouter: ModelRouter | null = null; @@ -228,6 +229,8 @@ export async function listModels(options: Omit\s*[{(]/], + keywords: ["javascript", "node", "react", "promise", "async"] + }, + { + language: "python", + patterns: [/\bdef\s+\w+\s*\(/i, /\bimport\s+\w+/i, /\bclass\s+\w+:/i], + keywords: ["python", "pytest", "django", "flask", "pandas"] + }, + { + language: "go", + patterns: [/\bpackage\s+\w+/i, /\bfunc\s+\w+\s*\(/i, /\bgo\s+test\b/i], + keywords: ["golang", "goroutine", "gofmt", "interface{}", "go.mod"] + }, + { + language: "rust", + patterns: [/\bfn\s+\w+\s*\(/i, /\blet\s+mut\b/i, /\bimpl\s+\w+/i], + keywords: ["rust", "cargo", "borrow", "trait", "lifetime"] + }, + { + language: "sql", + patterns: [/\bselect\s+.+\s+from\b/i, /\binsert\s+into\b/i, /\bwhere\s+\w+\s*=/i], + keywords: ["sql", "postgres", "mysql", "sqlite", "query"] + }, + { + language: "shell", + patterns: [/#!\/(?:usr\/bin\/env\s+)?(?:ba|z)?sh/i, /\bnpm\s+(ci|run)\b/i, /\bdocker\s+compose\b/i], + keywords: ["bash", "shell", "terminal", "cli", "docker"] + }, + { + language: "markdown", + patterns: [/^#{1,6}\s+\S/m, /\[[^\]]+\]\([^)]+\)/, /```[\s\S]*?```/], + keywords: ["markdown", "readme", "mdx", "frontmatter", "table"] + } +]; + +/** Picks the best-scoring LANGUAGE_RULES entry for the given input. Text is pulled out of the input with extractText and only the first MAX_SCAN_LENGTH characters are scanned. Text that is empty or whitespace-only short-circuits to language unknown with confidence 0 and no signals. */ export function detectLanguage(input: unknown): LanguageDetectionResult { + const text = extractText(input).slice(0, MAX_SCAN_LENGTH); + const normalized = text.toLowerCase(); + + if (!normalized.trim()) { + return { language: "unknown", confidence: 0, matchedSignals: [] }; + } + + 
/* Score every rule: regex patterns are tested against the original-case text, keywords against the lower-cased text. */ const ranked = LANGUAGE_RULES + .map(rule => { + const matchedPatterns = rule.patterns + .filter(pattern => pattern.test(text)) + .map(pattern => pattern.source); + const matchedKeywords = rule.keywords + .filter(keyword => normalized.includes(keyword.toLowerCase())); /* NOTE(review): includes is a substring check, so a keyword such as rust also matches inside trust; confirm this is intended. */ + const score = matchedPatterns.length * 2 + matchedKeywords.length; /* each pattern hit is worth 2, each keyword hit 1 */ + + return { + language: rule.language, + score, + matchedSignals: [...matchedKeywords, ...matchedPatterns] + }; + }) + .sort((a, b) => b.score - a.score); /* highest score first; equal scores keep LANGUAGE_RULES order when the engine sort is stable */ + + const best = ranked[0]; + /* no rule matched anything: report unknown with confidence 0.15, not the 0 used for blank input */ if (!best || best.score === 0) { + return { language: "unknown", confidence: 0.15, matchedSignals: [] }; + } + + /* confidence starts at 0.3, grows by 0.12 per score point and is capped at 0.98 */ return { + language: best.language, + confidence: Math.min(0.98, 0.3 + best.score * 0.12), + matchedSignals: best.matchedSignals + }; +} + +/** Flattens an arbitrary value into one string. Falsy values give the empty string, strings are returned unchanged, arrays are flattened recursively and joined with single spaces, and objects contribute their prompt, input, query, text, content and messages fields in that order. Any other type gives the empty string. NOTE(review): there is no guard against self-referencing objects, so a cyclic input would recurse without bound. */ function extractText(value: unknown): string { + if (!value) { + return ""; + } + + if (typeof value === "string") { + return value; + } + + if (Array.isArray(value)) { + return value.map(extractText).filter(Boolean).join(" "); + } + + if (typeof value !== "object") { + return ""; + } + + /* NOTE(review): Record appears here without type arguments; confirm against the real file. */ const record = value as Record; + return [ + record.prompt, + record.input, + record.query, + record.text, + record.content, + record.messages + ] + .map(extractText) + .filter(Boolean) + .join(" "); +} diff --git a/tests/language-detector.test.js b/tests/language-detector.test.js new file mode 100644 index 0000000..09e3690 --- /dev/null +++ b/tests/language-detector.test.js @@ -0,0 +1,50 @@ +const assert = require("assert"); +const { detectLanguage } = require("../dist/index.js"); + +const cases = [ + { + name: "detects TypeScript from typed declarations", + input: "interface User { id: string; active: boolean }", + language: "typescript" + }, + { + name: "detects Python from function syntax", + input: "def normalize_user(row):\n import pandas as pd\n return row", + language: "python" + }, + { + name: "detects Go from package and func syntax", + input: "package main\n\nfunc handler() error { return nil }", + language: 
"go" + }, + { + name: "detects SQL from query shape", + input: "SELECT id, email FROM users WHERE active = true", + language: "sql" + }, + { + name: "extracts nested chat messages", + input: { + messages: [ + { role: "user", content: "Please fix this pytest failure in my Flask app" } + ] + }, + language: "python" + } +]; + +/* every table case must pick the expected language with confidence above 0.3 and report at least one matched signal */ for (const testCase of cases) { + const result = detectLanguage(testCase.input); + assert.equal(result.language, testCase.language, testCase.name); + assert(result.confidence > 0.3, `${testCase.name} should be confident`); + assert(result.matchedSignals.length > 0, `${testCase.name} should include signals`); +} + +/* prose without any signal resolves to unknown but must still carry a non-zero confidence */ const unknown = detectLanguage("Can you help me think through this vague idea?"); +assert.equal(unknown.language, "unknown"); +assert(unknown.confidence > 0); + +/* the def signal sits behind 20000 filler characters and the test expects it to be ignored, presumably because it lies past MAX_SCAN_LENGTH (defined outside this view) */ const bounded = detectLanguage(`${"x".repeat(20000)} def late_signal(): pass`); +assert.equal(bounded.language, "unknown"); + +console.log("language-detector tests passed");