diff --git a/.context/eeglab-implementation-plan.md b/.context/eeglab-implementation-plan.md new file mode 100644 index 0000000..2cf45da --- /dev/null +++ b/.context/eeglab-implementation-plan.md @@ -0,0 +1,516 @@ +# EEGLab Community Implementation Plan + +**Date:** 2026-01-26 +**Status:** Planning + +## Overview + +Develop a comprehensive EEGLab community for OSA, providing researchers with access to EEGLab documentation, codebase knowledge, and 20+ years of community wisdom from the mailing list. + +## Resources Inventory + +### 1. Website & Documentation +- **Main Website:** [eeglab.org](https://eeglab.org) at https://github.com/sccn/sccn.github.io +- **Structure:** + - 11 core tutorial modules (data import → scripting) + - Plugin documentation (ICLabel, DIPFIT, LIMO, SIFT, NFT) + - Workshops and training materials + - FAQs and troubleshooting + - Integration guides (FieldTrip, HPC, Python, Octave) + - Citation and revision history + +### 2. GitHub Repositories (Top 20 by Stars) +Based on activity, stars, and community engagement: + +| Repo | Stars | Forks | Open Issues | Language | Priority | +|------|-------|-------|-------------|----------|----------| +| eeglab | 726 | 261 | 51 | MATLAB | **Critical** | +| labstreaminglayer | 706 | 181 | - | HTML | High | +| liblsl | 159 | 78 | - | C++ | High | +| ICLabel | 70 | 23 | - | MATLAB | High | +| clean_rawdata | 49 | 19 | - | MATLAB | Medium | +| roiconnect | 45 | 18 | - | HTML | Medium | +| PACTools | 42 | 10 | - | MATLAB | Medium | +| mobilab | 31 | 22 | - | MATLAB | Low | +| EEG-BIDS | 30 | 21 | - | MATLAB | Medium | +| eeglab_tutorial_scripts | 11 | 5 | - | MATLAB | Medium | + +**Recommended for Phase 1:** eeglab, ICLabel, clean_rawdata, EEG-BIDS, labstreaminglayer, liblsl + +### 3. Mailing List Archives +- **URL:** https://sccn.ucsd.edu/pipermail/eeglablist/ +- **Coverage:** 2004-2026 (22 years) +- **Organization:** Thread-based, subject-grouped, by-author, chronological +- **Peak Volume:** 2012 (5MB compressed) +- **Format:** HTML + compressed text files + +### 4. Key Papers +To track for citations and core knowledge: + +1. **Main EEGLab Paper:** [10.1016/j.jneumeth.2003.10.009](https://doi.org/10.1016/j.jneumeth.2003.10.009) + Delorme & Makeig (2004) - "EEGLAB: an open source toolbox for analysis of single-trial EEG dynamics" + +2. **ICLabel:** [10.1016/j.neuroimage.2019.05.026](https://doi.org/10.1016/j.neuroimage.2019.05.026) + Pion-Tonachini et al. (2019) - "ICLabel: An automated electroencephalographic independent component classifier" + +3. **PREP Pipeline:** (Need to find DOI) + +4. **Additional:** Search for: + - "EEGLAB tutorial" + - "EEGLAB plugin" + - "ICA EEG analysis" + - "EEG preprocessing" + +## Implementation Phases + +### Phase 1: Basic Community Setup (Week 1-2) + +**Goal:** Get EEGLab community running with documentation and GitHub repos + +1. **Create Community Config** (`src/assistants/eeglab/config.yaml`) + - Basic metadata (id: eeglab, name, description) + - System prompt tailored to EEG analysis workflows + - Documentation sources from sccn.github.io + - GitHub repos: eeglab, ICLabel, clean_rawdata, EEG-BIDS + - Paper queries and DOIs + +2. **Documentation Strategy** + - **Preloaded:** 2-3 core concepts (similar to HED) + - Installation/setup guide + - Basic tutorial structure + - Key concepts (ICA, artifact rejection, etc.) + - **On-demand:** Specific tutorials, plugin docs + +3. **Knowledge Tools** + - Reuse existing knowledge tools (search discussions, list recent, search papers) + - Same pattern as HED + +4. **Testing** + - Verify documentation retrieval + - Test GitHub sync + - Validate paper search + +**Deliverables:** +- [ ] `src/assistants/eeglab/config.yaml` +- [ ] Basic system prompt +- [ ] Documentation mapping +- [ ] GitHub sync working +- [ ] Manual testing with common questions + +### Phase 2: Docstring Extraction Tools (Week 3-4) + +**Goal:** Extract and index MATLAB/Python docstrings from codebases + +#### 2.1 MATLAB Docstring Extractor + +**Purpose:** Parse MATLAB files and extract function/script documentation + +**Implementation:** +- Create `src/tools/matlab_docstring_extractor.py` +- Strategy: + ```python + # Parse MATLAB file header comments + # Format: Lines starting with % before function definition + # Extract: + # - Function name + # - Purpose/description + # - Input parameters + # - Output parameters + # - Examples + # - See also references + ``` + +**Challenges:** +- MATLAB syntax variations +- Mixed comment styles +- Large codebase traversal + +**Solution:** +- Use regex patterns for MATLAB comment extraction +- Walk repository tree (recursive) +- Build searchable index +- Store in knowledge database + +**Tool Interface:** +```python +from langchain_core.tools import BaseTool + +def create_search_matlab_docs_tool( + community_id: str, + community_name: str, + repos: list[str], +) -> BaseTool: + """Search MATLAB function documentation from docstrings.""" + # Implementation +``` + +#### 2.2 Python Docstring Extractor + +**Purpose:** Extract Python docstrings (for Python-based EEG tools) + +**Implementation:** +- Create `src/tools/python_docstring_extractor.py` +- Use `ast` module to parse Python files +- Extract docstrings from: + - Functions + - Classes + - Methods + - Modules + +**Advantages over MATLAB:** +- Python AST parsing is built-in +- Standardized docstring formats (NumPy, Google, etc.) + +**Tool Interface:** +```python +def create_search_python_docs_tool( + community_id: str, + community_name: str, + repos: list[str], +) -> BaseTool: + """Search Python function documentation from docstrings.""" + # Implementation +``` + +#### 2.3 Integration with Sync System + +**Add to CLI:** +```bash +# Sync docstrings from GitHub repos +osa sync docstrings --community eeglab --language matlab +osa sync docstrings --community eeglab --language python + +# Or sync all +osa sync all --community eeglab +``` + +**Database Schema:** +```sql +CREATE TABLE IF NOT EXISTS docstrings ( + id INTEGER PRIMARY KEY, + community_id TEXT NOT NULL, + language TEXT NOT NULL, -- 'matlab' or 'python' + repo TEXT NOT NULL, + file_path TEXT NOT NULL, + symbol_name TEXT NOT NULL, -- function/class name + symbol_type TEXT NOT NULL, -- 'function', 'class', 'method' + docstring TEXT NOT NULL, + parameters TEXT, -- JSON array + returns TEXT, -- JSON object + examples TEXT, + indexed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(community_id, repo, file_path, symbol_name) +); + +CREATE INDEX idx_docstrings_search ON docstrings(community_id, symbol_name, docstring); +``` + +**Deliverables:** +- [ ] `src/tools/matlab_docstring_extractor.py` +- [ ] `src/tools/python_docstring_extractor.py` +- [ ] Database schema updates +- [ ] CLI sync commands +- [ ] LangChain tool wrappers +- [ ] Tests for both extractors + +### Phase 3: Mailing List FAQ Agent (Week 5-6) + +**Goal:** Summarize 22 years of mailing list discussions into searchable FAQ + +#### 3.1 Mailing List Scraper + +**Implementation:** +- Create `src/tools/mailman_scraper.py` +- Scrape HTML archives from https://sccn.ucsd.edu/pipermail/eeglablist/ +- Parse thread structure +- Extract: + - Thread title + - Original question + - Responses + - Thread metadata (date, participants) + +**Challenges:** +- 22 years of data (~5MB peak) +- Rate limiting +- HTML parsing consistency + +**Strategy:** +- Incremental scraping (year by year) +- Cache raw HTML locally +- Resume on failure + +#### 3.2 FAQ Summarization Agent + +**Purpose:** Use LLM to summarize Q&A threads into concise FAQ entries + +**Implementation:** +```python +# src/tools/faq_summarizer.py + +from src.core.services.llm import create_llm + +async def summarize_thread(thread_data: MailingListThread) -> FAQEntry: + """ + Use LLM to: + 1. Extract the core question + 2. Identify key responses + 3. Synthesize a concise answer + 4. Tag with categories + """ + llm = create_llm(model="qwen/qwen3-235b-a22b-2507") + + prompt = f""" + Summarize this EEGLab mailing list discussion into a FAQ entry. + + Thread: {thread_data.title} + Question: {thread_data.original_post} + Responses: {thread_data.responses} + + Extract: + 1. Core question (1-2 sentences) + 2. Best answer (2-3 paragraphs max) + 3. Related topics/tags + 4. Link to full thread + """ + + # Get summary + # Validate quality + # Store in database +``` + +**Cost Management:** +- Only summarize threads with >2 responses (indicates valuable discussion) +- Batch processing +- Use cheaper model for initial filtering, better model for final summary +- Cache summaries + +**Database Schema:** +```sql +CREATE TABLE IF NOT EXISTS mailing_list_faqs ( + id INTEGER PRIMARY KEY, + community_id TEXT NOT NULL, + thread_url TEXT NOT NULL UNIQUE, + thread_title TEXT NOT NULL, + question TEXT NOT NULL, + answer TEXT NOT NULL, + tags TEXT, -- JSON array + participants TEXT, -- JSON array + response_count INTEGER, + thread_date TEXT, + summarized_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX idx_faq_search ON mailing_list_faqs(community_id, question, answer, tags); +``` + +#### 3.3 FAQ Search Tool + +**Tool Interface:** +```python +def create_search_faq_tool( + community_id: str, + community_name: str, +) -> BaseTool: + """Search FAQ entries from mailing list history.""" + + def search_faq_impl(query: str, limit: int = 5) -> str: + # Full-text search on questions and answers + # Return: + # - Question + # - Answer summary + # - Link to full thread + + return StructuredTool.from_function(...) +``` + +**Deliverables:** +- [ ] `src/tools/mailman_scraper.py` +- [ ] `src/tools/faq_summarizer.py` +- [ ] Database schema +- [ ] CLI command: `osa sync mailing-list --community eeglab` +- [ ] LangChain FAQ search tool +- [ ] Cost estimation and budget +- [ ] Tests + +### Phase 4: Integration & Testing (Week 7) + +**Goal:** Integrate all components and test end-to-end + +1. **Update EEGLab Config** + - Add docstring search tool + - Add FAQ search tool + - Update system prompt with tool usage instructions + +2. **Frontend Widget** + - Test widget embedding on hypothetical EEGLab site + - Verify CORS settings + - Test model selection + +3. **Comprehensive Testing** + - Unit tests for each tool + - Integration tests for full workflow + - Test with real EEG researcher questions + - Performance testing (response times, DB queries) + +4. **Documentation** + - User guide for EEGLab assistant + - Developer guide for maintaining tools + - Sync workflow documentation + +**Deliverables:** +- [ ] Complete `src/assistants/eeglab/config.yaml` +- [ ] All tools integrated +- [ ] Test suite passing +- [ ] Documentation complete +- [ ] Demo video/screenshots + +### Phase 5: Deployment (Week 8) + +**Goal:** Deploy to production servers + +1. **Backend Deployment** + ```bash + # On hedtools server + cd ~/osa + git pull origin develop + deploy/deploy.sh dev # Test on dev first + # Verify + deploy/deploy.sh prod # Deploy to production + ``` + +2. **Knowledge Base Population** + ```bash + # Initialize database + osa sync init --community eeglab + + # Sync GitHub repos + osa sync github --community eeglab + + # Sync docstrings (may take hours) + osa sync docstrings --community eeglab --language matlab + osa sync docstrings --community eeglab --language python + + # Sync papers + osa sync papers --community eeglab + + # Sync mailing list (may take days) + osa sync mailing-list --community eeglab --batch-size 100 + ``` + +3. **Monitoring** + - Check LangFuse for usage + - Monitor response quality + - Gather user feedback + +**Deliverables:** +- [ ] Dev deployment tested +- [ ] Prod deployment complete +- [ ] Knowledge base populated +- [ ] Monitoring dashboard +- [ ] User feedback mechanism + +## Technical Decisions + +### Tool Architecture + +**Reuse Existing:** +- Knowledge discovery tools (search discussions, list recent, search papers) +- Documentation fetcher +- Base tool infrastructure + +**New Tools Needed:** +1. MATLAB docstring extractor +2. Python docstring extractor +3. FAQ search (mailing list) + +### Database Strategy + +**SQLite per community** (existing pattern): +- `knowledge/eeglab.db` +- Tables: + - `github_items` (existing) + - `papers` (existing) + - `docstrings` (new) + - `mailing_list_faqs` (new) + +### Cost Considerations + +**LLM Costs:** +1. **FAQ Summarization:** Most expensive + - Estimate: 22 years × 365 days × 5 threads/day = ~40,000 threads + - Filter to valuable threads (>2 responses): ~10,000 threads + - Cost per summary: ~$0.01 (with qwen3-235b-a22b) + - Total: ~$100-200 + +2. **Docstring Extraction:** Free + - No LLM needed, pure parsing + +3. **Regular Usage:** Standard rates + - Documentation retrieval: free (HTTP fetch) + - Tool calls: minimal cost + +**Strategy:** +- Run FAQ summarization as one-time batch job +- Incremental updates monthly +- Budget: $200 for initial setup + +### Performance Considerations + +**Indexing:** +- Full-text search on docstrings +- FTS5 for mailing list FAQ +- Regular PRAGMA optimize + +**Caching:** +- Documentation pages cached client-side +- Docstring index in memory for fast lookup + +## Open Questions + +1. **PREP Pipeline Paper:** Need to find the canonical DOI +2. **Additional Repos:** Should we include dipfit, cleanline, bva-io? +3. **Mailing List Rate Limiting:** What are the limits on pipermail archives? +4. **Custom EEGLab Tools:** Beyond general tools, do we need EEG-specific tools (e.g., channel location lookup, event code reference)? + +## Success Metrics + +1. **Coverage:** + - [ ] All major EEGLab tutorials indexed + - [ ] Top 6 repos synced + - [ ] 10,000+ FAQ entries + - [ ] 5,000+ docstrings indexed + +2. **Quality:** + - [ ] 90%+ accurate responses to common questions + - [ ] Proper citations to original sources + - [ ] Response time < 3 seconds + +3. **Adoption:** + - [ ] 100+ queries in first month + - [ ] Positive feedback from EEG researchers + - [ ] Integration with SCCN website (future) + +## Dependencies + +- [ ] Access to SCCN GitHub repos (public, no auth needed) +- [ ] Mailing list scraping allowed (check robots.txt) +- [ ] Budget approval for FAQ summarization (~$200) +- [ ] Test users for feedback (SCCN researchers) + +## Timeline Summary + +- **Week 1-2:** Basic setup (config, docs, GitHub sync) +- **Week 3-4:** Docstring extraction tools +- **Week 5-6:** Mailing list FAQ agent +- **Week 7:** Integration & testing +- **Week 8:** Deployment + +**Total:** 8 weeks for complete implementation + +## Next Steps + +1. Create GitHub epic/issue with this plan +2. Get approval from SCCN team +3. Clone necessary repos to ~/Documents/git/sccn/ +4. Start Phase 1 implementation +5. Set up weekly progress reviews diff --git a/frontend/osa-chat-widget.js b/frontend/osa-chat-widget.js index 3e39e98..d84ee0b 100644 --- a/frontend/osa-chat-widget.js +++ b/frontend/osa-chat-widget.js @@ -1227,19 +1227,45 @@ // Save chat history to localStorage let saveErrorShown = false; function saveHistory() { + if (!CONFIG.storageKey) { + console.warn('[OSA] Cannot save history - no storage key configured'); + return; + } + try { - localStorage.setItem(CONFIG.storageKey, JSON.stringify(messages)); + const data = JSON.stringify(messages); + localStorage.setItem(CONFIG.storageKey, data); saveErrorShown = false; } catch (e) { - console.error('Failed to save chat history:', e); - // Show error once per session to avoid spam - if (!saveErrorShown) { - const container = document.querySelector('.osa-chat-widget'); - if (container) { - showError(container, 'Chat history could not be saved. Storage may be full or disabled.'); - } - saveErrorShown = true; + console.error('[OSA] localStorage save failed:', { + errorName: e.name, + errorMessage: e.message, + messageCount: messages.length, + isQuotaError: e.name === 'QuotaExceededError' + }); + + // Determine error type for better user messaging + let errorMsg = 'Chat history could not be saved'; + const isQuotaError = e.name === 'QuotaExceededError'; + const isSecurityError = e.name === 'SecurityError'; + + if (isQuotaError) { + errorMsg = 'Storage full - conversation NOT saved. Clear browser data or export chat.'; + } else if (isSecurityError) { + errorMsg = 'Browser privacy settings prevent saving. Enable local storage.'; + } else { + errorMsg = 'Storage unavailable - conversation will be lost on refresh.'; + } + + // Show error (not just once - user needs to know every time save fails) + const container = document.querySelector('.osa-chat-widget'); + if (container && !saveErrorShown) { + showError(container, errorMsg); + saveErrorShown = true; // Show once per session to avoid spam } + + // Re-throw so callers know save failed + throw e; } } @@ -2091,10 +2117,16 @@ throw error; // Re-throw to be handled by sendMessage } finally { // Always release the reader to free resources - try { - reader.releaseLock(); - } catch (e) { - // Reader may already be closed, ignore errors + if (reader) { + try { + reader.releaseLock(); + } catch (releaseError) { + // Log cleanup failures - they indicate serious issues + console.error('[OSA] Failed to release stream reader:', { + errorName: releaseError.name, + errorMessage: releaseError.message + }); + } } } } @@ -2104,7 +2136,12 @@ if (isLoading || !question.trim()) return; isLoading = true; + + // Track message indices to avoid corruption on error + const userMessageIndex = messages.length; messages.push({ role: 'user', content: question }); + let assistantMessageCreated = false; + renderMessages(container); renderSuggestions(container); @@ -2157,6 +2194,7 @@ method: 'POST', headers: headers, body: JSON.stringify(body), + signal: AbortSignal.timeout(120000), // 2 minute timeout for connection + streaming }); if (!response.ok) { @@ -2185,6 +2223,7 @@ const contentType = response.headers.get('content-type') || ''; if (CONFIG.streamingEnabled && contentType.includes('text/event-stream')) { // Handle streaming response + assistantMessageCreated = true; // handleStreamingResponse creates assistant message await handleStreamingResponse(response, container); } else if (CONFIG.streamingEnabled && !contentType.includes('text/event-stream')) { // Streaming was expected but not received - log for debugging @@ -2243,19 +2282,30 @@ console.error('[OSA] Send message error:', error); showError(container, userMessage); - // Check if we need to clean up messages - // If handleStreamingResponse threw and kept partial content, assistant message is already in array - // If handleStreamingResponse threw and had no content, it already popped the assistant message - // We need to remove the user message only if it's the last message - const lastMessage = messages[messages.length - 1]; - if (lastMessage && lastMessage.role === 'user' && lastMessage.content === question) { - messages.pop(); + // Clean up messages based on what was created + // If streaming was attempted, handleStreamingResponse manages its own assistant message + // We only need to remove the user message if no assistant response exists + if (assistantMessageCreated) { + // handleStreamingResponse created an assistant message + // If it has content (partial or complete), keep both user and assistant messages + // If it has no content, handleStreamingResponse already removed it, so remove user message too + const lastMessage = messages[messages.length - 1]; + if (lastMessage && lastMessage.role === 'user' && messages.length === userMessageIndex + 1) { + // No assistant message remains, remove user message + messages.splice(userMessageIndex, 1); + } + } else { + // No streaming attempted, no assistant message created, remove user message + if (messages.length > userMessageIndex && messages[userMessageIndex].role === 'user') { + messages.splice(userMessageIndex, 1); + } } try { saveHistory(); } catch (saveError) { console.error('[OSA] Failed to save history after error:', saveError); + // saveHistory already showed error to user } updateStatusDisplay(false); } finally { diff --git a/src/api/routers/community.py b/src/api/routers/community.py index 3304503..a756561 100644 --- a/src/api/routers/community.py +++ b/src/api/routers/community.py @@ -1060,14 +1060,40 @@ async def _stream_ask_response( sse_event = {"event": "done"} yield f"data: {json.dumps(sse_event)}\n\n" + except HTTPException: + # Don't catch our own HTTP exceptions - let them propagate + raise + except ValueError as e: + # Input validation errors - user's fault + logger.warning("Invalid input in streaming for community %s: %s", community_id, e) + sse_event = { + "event": "error", + "message": f"Invalid request: {str(e)}", + "retryable": False, + } + yield f"data: {json.dumps(sse_event)}\n\n" except Exception as e: + # Unexpected errors - log with full context + import uuid + + error_id = str(uuid.uuid4()) logger.error( - "Streaming error in ask endpoint for community %s: %s", + "Unexpected streaming error (ID: %s) in ask endpoint for community %s: %s", + error_id, community_id, e, exc_info=True, + extra={ + "error_id": error_id, + "community_id": community_id, + "error_type": type(e).__name__, + }, ) - sse_event = {"event": "error", "message": str(e)} + sse_event = { + "event": "error", + "message": "An error occurred while generating the response. Please try again.", + "error_id": error_id, + } yield f"data: {json.dumps(sse_event)}\n\n" diff --git a/src/cli/sync.py b/src/cli/sync.py index 274c806..177889d 100644 --- a/src/cli/sync.py +++ b/src/cli/sync.py @@ -18,6 +18,7 @@ from src.assistants import registry from src.cli.config import load_config from src.knowledge.db import get_db_path, get_stats, init_db +from src.knowledge.docstring_sync import sync_repo_docstrings from src.knowledge.github_sync import sync_repo, sync_repos from src.knowledge.papers_sync import ( sync_all_papers, @@ -330,6 +331,80 @@ def sync_papers( console.print(f"\n[green]Total papers synced for {community}: {total}[/green]") +@sync_app.command("docstrings") +def sync_docstrings( + community: Annotated[ + str, + typer.Option("--community", "-c", help="Community ID to sync (e.g., hed, bids, eeglab)"), + ] = "hed", + language: Annotated[ + str, + typer.Option("--language", "-l", help="Language: matlab or python"), + ] = "matlab", + repo: Annotated[ + str | None, + typer.Option("--repo", "-r", help="Single repo to sync (owner/name format)"), + ] = None, + branch: Annotated[ + str, + typer.Option("--branch", "-b", help="Branch to sync from"), + ] = "main", +) -> None: + """Sync code docstrings from GitHub repositories. + + Extracts docstrings from MATLAB (.m) or Python (.py) files and indexes them + for search. If --repo is specified, syncs that single repo. Otherwise, syncs + all repos configured for the community. + """ + _require_admin() + _validate_community(community) + + if language not in ("matlab", "python"): + console.print("[red]Error: Language must be 'matlab' or 'python'[/red]") + raise typer.Exit(1) + + if not _safe_init_db(community): + raise typer.Exit(1) + + if repo: + # Sync single repo + try: + count = sync_repo_docstrings(repo, language, project=community, branch=branch) + console.print(f"\n[green]✓ Synced {count} {language} docstrings from {repo}[/green]") + except Exception as e: + console.print(f"[red]Error syncing {repo}: {e}[/red]") + logger.exception("Failed to sync docstrings from %s", repo) + raise typer.Exit(1) + else: + # Sync all repos from community config + repos = _get_community_repos(community) + if not repos: + console.print(f"[yellow]No repos configured for community '{community}'[/yellow]") + console.print("[dim]Use --repo to specify a repository explicitly[/dim]") + raise typer.Exit(1) + + console.print(f"[dim]Syncing {language} docstrings from {len(repos)} repos...[/dim]\n") + total = 0 + failed = [] + + for repo_name in repos: + try: + count = sync_repo_docstrings(repo_name, language, project=community, branch=branch) + total += count + if count > 0: + console.print(f" ✓ {repo_name}: {count} docstrings") + else: + console.print(f" - {repo_name}: no {language} files found") + except Exception as e: + console.print(f" ✗ {repo_name}: {e}") + logger.warning("Failed to sync docstrings from %s: %s", repo_name, e) + failed.append(repo_name) + + console.print(f"\n[green]Total {language} docstrings synced: {total}[/green]") + if failed: + console.print(f"[yellow]Failed repos ({len(failed)}): {', '.join(failed)}[/yellow]") + + @sync_app.command("all") def sync_all( community: Annotated[ diff --git a/src/core/services/litellm_llm.py b/src/core/services/litellm_llm.py index 47c5631..c5e6aac 100644 --- a/src/core/services/litellm_llm.py +++ b/src/core/services/litellm_llm.py @@ -28,11 +28,16 @@ ]) """ +import logging import os +from collections.abc import AsyncIterator, Iterator from typing import Any from langchain_core.language_models import BaseChatModel from langchain_core.messages import BaseMessage +from langchain_core.runnables import Runnable + +logger = logging.getLogger(__name__) def create_openrouter_llm( @@ -117,91 +122,430 @@ class CachingLLMWrapper(BaseChatModel): The cache_control parameter tells Anthropic to cache the content, reducing costs by 90% on cache hits (after initial 25% cache write premium). + Supports wrapping both direct LLMs (BaseChatModel) and tool-bound models + (RunnableBinding) to preserve caching through tool binding. When bind_tools() + is called, it returns a new CachingLLMWrapper around the RunnableBinding, + creating a chain: CachingLLMWrapper -> RunnableBinding -> BaseChatModel. + + This nested structure ensures cache_control markers are applied to all + invocations, including tool calls, preventing the 10x cost increase that + would occur if caching were bypassed. + Minimum cacheable prompt: 1024 tokens for Claude Sonnet/Opus, 4096 for Haiku 4.5 Cache TTL: 5 minutes (refreshed on each hit) """ - llm: BaseChatModel - """The underlying LLM to wrap.""" + llm: BaseChatModel | Runnable + """The underlying LLM or Runnable to wrap.""" model_config = {"arbitrary_types_allowed": True} - def __init__(self, llm: BaseChatModel, **kwargs): + def __init__(self, llm: BaseChatModel | Runnable, **kwargs): + """Initialize the caching wrapper. + + Args: + llm: The underlying LLM or Runnable to wrap + **kwargs: Additional arguments for BaseChatModel + + Raises: + ValueError: If llm is already a CachingLLMWrapper (prevents double-wrapping) + TypeError: If llm lacks required methods + """ + # Prevent wrapping a CachingLLMWrapper (infinite recursion risk) + if isinstance(llm, CachingLLMWrapper): + raise ValueError( + "Cannot wrap a CachingLLMWrapper with another CachingLLMWrapper. " + "This would create infinite recursion. If you need to bind tools, " + "call bind_tools() on the existing wrapper instead." + ) + + # Validate llm has required methods + if not hasattr(llm, "invoke"): + raise TypeError( + f"Cannot wrap {type(llm).__name__}: missing required 'invoke' method. " + "The LLM must implement at least the 'invoke' method." + ) + + logger.debug("Initialized CachingLLMWrapper wrapping %s", type(llm).__name__) super().__init__(llm=llm, **kwargs) @property def _llm_type(self) -> str: return "caching_llm_wrapper" - def bind_tools(self, tools: list, **kwargs): - """Bind tools to the underlying LLM. + def bind_tools(self, tools: list, **kwargs) -> "CachingLLMWrapper": + """Bind tools while preserving caching functionality. + + This method performs a two-step process: + 1. Delegates tool binding to the underlying LLM (returns RunnableBinding) + 2. Wraps the result in a new CachingLLMWrapper to preserve caching - Delegates tool binding to the underlying LLM and returns the bound model. - Note: Returns a RunnableBinding, not a CachingLLMWrapper, because - tool-bound models need to handle tool calls directly. + This ensures cache_control markers are applied to all invocations of the + tool-bound model, preventing the 10x cost increase that would occur if + caching were bypassed during tool calls. + + Args: + tools: List of tools to bind + **kwargs: Additional arguments for tool binding + + Returns: + New CachingLLMWrapper instance wrapping the tool-bound RunnableBinding + + Raises: + ValueError: If tools list is empty + NotImplementedError: If underlying LLM doesn't support tool binding + TypeError: If tool binding fails due to type issues """ - return self.llm.bind_tools(tools, **kwargs) + # Validate tools list + if not tools: + logger.error("Cannot bind empty tools list") + raise ValueError("Cannot bind empty tools list. Provide at least one tool to bind.") + + # Check if underlying LLM supports bind_tools + if not hasattr(self.llm, "bind_tools"): + logger.error("Underlying LLM %s does not support bind_tools", type(self.llm).__name__) + raise NotImplementedError( + f"Underlying LLM {type(self.llm).__name__} does not support tool binding. " + "Use a different LLM that implements bind_tools()." + ) + + try: + # Bind tools to underlying LLM + logger.debug("Binding %d tools to %s", len(tools), type(self.llm).__name__) + bound_llm = self.llm.bind_tools(tools, **kwargs) + + # Wrap in CachingLLMWrapper to preserve caching + wrapped_llm = CachingLLMWrapper(llm=bound_llm) + logger.debug("Successfully bound tools and wrapped in CachingLLMWrapper") + return wrapped_llm + + except NotImplementedError as e: + logger.error( + "Tool binding not implemented for %s: %s", + type(self.llm).__name__, + str(e), + ) + raise NotImplementedError( + f"Tool binding failed: {type(self.llm).__name__} does not implement bind_tools(). " + f"Original error: {str(e)}" + ) from e + except TypeError as e: + logger.error( + "Type error during tool binding for %s: %s", + type(self.llm).__name__, + str(e), + ) + raise TypeError( + f"Tool binding failed due to type mismatch: {str(e)}. " + "Check that tools are properly formatted LangChain tool objects." + ) from e + except ValueError as e: + logger.error( + "Value error during tool binding for %s: %s", + type(self.llm).__name__, + str(e), + ) + raise def _add_cache_control(self, messages: list[BaseMessage]) -> list[dict]: """Transform messages to add cache_control to system messages. + Applies cache_control markers only to SystemMessage instances. Other message + types (HumanMessage, AIMessage) are passed through unchanged. + + Validation is strict with fail-fast behavior: + - Messages must have a 'content' attribute (ValueError if missing) + - Message content must not be None (ValueError if None) + - Messages list must be a list, not None (ValueError/TypeError) + Args: messages: List of LangChain messages Returns: List of message dicts with cache_control on system messages + + Raises: + ValueError: If messages is None, contains messages without content, + or contains messages with None content + TypeError: If messages is not a list """ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage + # Validate input + if messages is None: + logger.error("Cannot transform None messages list") + raise ValueError("Messages list cannot be None") + + if not isinstance(messages, list): + logger.error("Expected list of messages, got %s", type(messages).__name__) + raise TypeError(f"Expected list of messages, got {type(messages).__name__}") + result = [] - for msg in messages: - if isinstance(msg, SystemMessage): - # Transform system message to multipart format with cache_control - result.append( - { - "role": "system", - "content": [ - { - "type": "text", - "text": msg.content, - "cache_control": {"type": "ephemeral"}, - } - ], - } + for i, msg in enumerate(messages): + try: + # Validate message has content attribute + if not hasattr(msg, "content"): + logger.error("Message at index %d missing content attribute", i) + raise ValueError( + f"Invalid message at index {i}: missing 'content' attribute. " + f"Message type: {type(msg).__name__}. All messages must have a 'content' attribute." + ) + + if isinstance(msg, SystemMessage): + # Validate content is not None + if msg.content is None: + logger.error("SystemMessage at index %d has None content", i) + raise ValueError( + f"SystemMessage at index {i} has None content. " + "All system messages must have non-None content." + ) + content = str(msg.content) + + # Transform system message to multipart format with cache_control + result.append( + { + "role": "system", + "content": [ + { + "type": "text", + "text": content, + "cache_control": {"type": "ephemeral"}, + } + ], + } + ) + logger.debug("Added cache_control to SystemMessage at index %d", i) + + elif isinstance(msg, HumanMessage): + if msg.content is None: + logger.error("HumanMessage at index %d has None content", i) + raise ValueError( + f"HumanMessage at index {i} has None content. " + "All messages must have non-None content." + ) + result.append({"role": "user", "content": str(msg.content)}) + + elif isinstance(msg, AIMessage): + if msg.content is None: + logger.error("AIMessage at index %d has None content", i) + raise ValueError( + f"AIMessage at index {i} has None content. " + "All messages must have non-None content." + ) + result.append({"role": "assistant", "content": str(msg.content)}) + + else: + # Fallback for other message types + logger.debug( + "Unknown message type %s at index %d, treating as user message", + type(msg).__name__, + i, + ) + if msg.content is None: + logger.error("Message at index %d has None content", i) + raise ValueError( + f"Message at index {i} has None content. " + "All messages must have non-None content." + ) + result.append({"role": "user", "content": str(msg.content)}) + + except (ValueError, AttributeError, UnicodeError) as e: + logger.error( + "Error processing message at index %d: %s (%s)", + i, + str(e), + type(e).__name__, + ) + raise + except Exception as e: + logger.error( + "Unexpected error processing message at index %d: %s (%s)", + i, + str(e), + type(e).__name__, + exc_info=True, ) - elif isinstance(msg, HumanMessage): - result.append({"role": "user", "content": msg.content}) - elif isinstance(msg, AIMessage): - result.append({"role": "assistant", "content": msg.content}) - else: - # Fallback for other message types - result.append({"role": "user", "content": str(msg.content)}) + raise + logger.debug( + "Transformed %d messages, added cache_control to %d system messages", + len(messages), + sum(1 for msg in messages if isinstance(msg, SystemMessage)), + ) return result def _generate(self, messages: list[BaseMessage], **kwargs) -> Any: """Generate response with cache_control on system messages.""" - cached_messages = self._add_cache_control(messages) - return self.llm._generate(cached_messages, **kwargs) + logger.debug("Generating response for %d messages", len(messages)) + try: + cached_messages = self._add_cache_control(messages) + return self.llm._generate(cached_messages, **kwargs) + except Exception as e: + logger.error( + "Error in _generate for %s: %s", + type(self.llm).__name__, + str(e), + exc_info=True, + ) + raise async def _agenerate(self, messages: list[BaseMessage], **kwargs) -> Any: """Async generate response with cache_control on system messages.""" - cached_messages = self._add_cache_control(messages) - return await self.llm._agenerate(cached_messages, **kwargs) + logger.debug("Async generating response for %d messages", len(messages)) + try: + cached_messages = self._add_cache_control(messages) + return await self.llm._agenerate(cached_messages, **kwargs) + except Exception as e: + logger.error( + "Error in _agenerate for %s: %s", + type(self.llm).__name__, + str(e), + exc_info=True, + ) + raise def invoke(self, messages: list[BaseMessage], **kwargs) -> Any: """Invoke LLM with cache_control on system messages.""" - cached_messages = self._add_cache_control(messages) - return self.llm.invoke(cached_messages, **kwargs) + logger.debug("Invoking %s with %d messages", type(self.llm).__name__, len(messages)) + try: + cached_messages = self._add_cache_control(messages) + return self.llm.invoke(cached_messages, **kwargs) + except Exception as e: + logger.error( + "Error invoking %s: %s", + type(self.llm).__name__, + str(e), + exc_info=True, + ) + raise async def ainvoke(self, messages: list[BaseMessage], **kwargs) -> Any: """Async invoke LLM with cache_control on system messages.""" - cached_messages = self._add_cache_control(messages) - return await self.llm.ainvoke(cached_messages, **kwargs) + logger.debug("Async invoking %s with %d messages", type(self.llm).__name__, len(messages)) + try: + cached_messages = self._add_cache_control(messages) + return await self.llm.ainvoke(cached_messages, **kwargs) + except Exception as e: + logger.error( + "Error async invoking %s: %s", + type(self.llm).__name__, + str(e), + exc_info=True, + ) + raise + + def stream(self, input: list[BaseMessage] | Any, config: Any = None, **kwargs) -> Iterator[Any]: + """Stream with cache_control applied to system messages. + + Applies cache_control transformation only if input is a list of messages. + Non-list inputs are passed through unchanged to the underlying LLM's stream method. + + Args: + input: Messages to stream (can be list of BaseMessage or other formats) + config: Optional runtime configuration + **kwargs: Additional arguments for streaming + + Yields: + Stream chunks from the underlying LLM + Raises: + ValueError: If input is None or invalid + NotImplementedError: If underlying LLM doesn't support streaming + Exception: Any exception raised by the underlying LLM's stream() method + """ + # Validate input + if input is None: + logger.error("Cannot stream with None input") + raise ValueError("Input cannot be None for streaming") + + # Check if underlying LLM supports streaming + if not (hasattr(self.llm, "stream") and callable(self.llm.stream)): + logger.error( + "Underlying LLM %s does not support streaming", + type(self.llm).__name__, + ) + raise NotImplementedError( + f"Underlying LLM {type(self.llm).__name__} does not support streaming. " + "To use streaming, either: (1) use a different LLM model that supports streaming, " + "or (2) use invoke() instead of stream() for non-streaming responses." + ) + + # Apply caching if input is a message list + if isinstance(input, list): + logger.debug("Applying cache_control to %d messages for streaming", len(input)) + input = self._add_cache_control(input) + else: + logger.warning( + "Input is not a message list (got %s), caching disabled for this stream. " + "This may result in higher API costs.", + type(input).__name__, + ) + + logger.debug("Starting stream from %s", type(self.llm).__name__) + return self.llm.stream(input, config=config, **kwargs) + + async def astream( + self, input: list[BaseMessage] | Any, config: Any = None, **kwargs + ) -> AsyncIterator[Any]: + """Async stream with cache_control applied to system messages. + + Applies cache_control transformation only if input is a list of messages. + Non-list inputs are passed through unchanged to the underlying LLM's astream method. + + Args: + input: Messages to stream (can be list of BaseMessage or other formats) + config: Optional runtime configuration + **kwargs: Additional arguments for streaming + + Yields: + Stream chunks from the underlying LLM -# Current Anthropic models (for reference) -# Note: Caching is enabled for ALL models by default; OpenRouter handles gracefully + Raises: + ValueError: If input is None or invalid + NotImplementedError: If underlying LLM doesn't support async streaming + Exception: Any exception raised by the underlying LLM's astream() method + """ + # Validate input + if input is None: + logger.error("Cannot async stream with None input") + raise ValueError("Input cannot be None for streaming") + + # Check if underlying LLM supports async streaming + if not (hasattr(self.llm, "astream") and callable(self.llm.astream)): + logger.error( + "Underlying LLM %s does not support async streaming", + type(self.llm).__name__, + ) + raise NotImplementedError( + f"Underlying LLM {type(self.llm).__name__} does not support async streaming. " + "To use streaming, either: (1) use a different LLM model that supports streaming, " + "or (2) use ainvoke() instead of astream() for non-streaming responses." + ) + + # Apply caching if input is a message list + if isinstance(input, list): + logger.debug( + "Applying cache_control to %d messages for async stream", + len(input), + ) + input = self._add_cache_control(input) + else: + logger.warning( + "Input is not a message list (got %s), caching disabled for this stream. " + "This may result in higher API costs.", + type(input).__name__, + ) + + logger.debug("Starting async stream from %s", type(self.llm).__name__) + async for chunk in self.llm.astream(input, config=config, **kwargs): + yield chunk + + +# Reference list of known Anthropic Claude models supporting prompt caching +# This is informational only - the is_cacheable_model() function uses a permissive +# heuristic (any "anthropic/claude-*" model) rather than this restrictive list. +# Caching is enabled by default for all models; OpenRouter/LiteLLM handle +# unsupported models gracefully by ignoring cache_control parameters. CACHEABLE_MODELS = { "claude-opus-4.5": "anthropic/claude-opus-4.5", "claude-sonnet-4.5": "anthropic/claude-sonnet-4.5", @@ -210,16 +554,23 @@ async def ainvoke(self, messages: list[BaseMessage], **kwargs) -> Any: def is_cacheable_model(model: str) -> bool: - """Check if a model supports Anthropic prompt caching. + """Check if a model identifier suggests Anthropic prompt caching support. + + Uses a heuristic check: returns True for model identifiers in the known + cacheable models list, or any identifier starting with "anthropic/claude-". + + Note: This is optimistic and may return True for models that don't actually + support caching. The LiteLLM/OpenRouter layer handles unsupported models + gracefully by ignoring cache_control parameters. Args: - model: Model identifier + model: Model identifier (e.g., "anthropic/claude-haiku-4.5") Returns: - True if the model supports cache_control + True if the model likely supports cache_control based on its identifier """ # Check exact match in aliases if model in CACHEABLE_MODELS: return True - # Check if it's an Anthropic Claude model + # Check if it's an Anthropic Claude model (permissive heuristic) return model.startswith("anthropic/claude-") diff --git a/src/knowledge/db.py b/src/knowledge/db.py index 2734b49..d2206bd 100644 --- a/src/knowledge/db.py +++ b/src/knowledge/db.py @@ -114,11 +114,53 @@ UNIQUE(source_type, source_name) ); +-- Docstrings extracted from source code +CREATE TABLE IF NOT EXISTS docstrings ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repo TEXT NOT NULL, + file_path TEXT NOT NULL, + language TEXT NOT NULL, + symbol_name TEXT NOT NULL, + symbol_type TEXT NOT NULL, + docstring TEXT NOT NULL, + line_number INTEGER, + synced_at TEXT NOT NULL, + UNIQUE(repo, file_path, symbol_name) +); + +-- FTS5 virtual table for full-text search on docstrings +CREATE VIRTUAL TABLE IF NOT EXISTS docstrings_fts USING fts5( + symbol_name, + docstring, + content='docstrings', + content_rowid='id' +); + +-- Triggers to keep FTS in sync with docstrings +CREATE TRIGGER IF NOT EXISTS docstrings_ai AFTER INSERT ON docstrings BEGIN + INSERT INTO docstrings_fts(rowid, symbol_name, docstring) + VALUES (new.id, new.symbol_name, new.docstring); +END; + +CREATE TRIGGER IF NOT EXISTS docstrings_ad AFTER DELETE ON docstrings BEGIN + INSERT INTO docstrings_fts(docstrings_fts, rowid, symbol_name, docstring) + VALUES('delete', old.id, old.symbol_name, old.docstring); +END; + +CREATE TRIGGER IF NOT EXISTS docstrings_au AFTER UPDATE ON docstrings BEGIN + INSERT INTO docstrings_fts(docstrings_fts, rowid, symbol_name, docstring) + VALUES('delete', old.id, old.symbol_name, old.docstring); + INSERT INTO docstrings_fts(rowid, symbol_name, docstring) + VALUES (new.id, new.symbol_name, new.docstring); +END; + -- Indexes for efficient queries CREATE INDEX IF NOT EXISTS idx_github_items_repo ON github_items(repo); CREATE INDEX IF NOT EXISTS idx_github_items_status ON github_items(status); CREATE INDEX IF NOT EXISTS idx_github_items_type ON github_items(item_type); CREATE INDEX IF NOT EXISTS idx_papers_source ON papers(source); +CREATE INDEX IF NOT EXISTS idx_docstrings_repo ON docstrings(repo); +CREATE INDEX IF NOT EXISTS idx_docstrings_language ON docstrings(language); """ @@ -278,6 +320,48 @@ def upsert_paper( ) +def upsert_docstring( + conn: sqlite3.Connection, + *, + repo: str, + file_path: str, + language: str, + symbol_name: str, + symbol_type: str, + docstring: str, + line_number: int | None = None, +) -> None: + """Insert or update a docstring entry. + + Args: + conn: Database connection + repo: Repository in owner/name format + file_path: Relative path from repo root + language: 'matlab' or 'python' + symbol_name: Function/class/method name + symbol_type: 'function', 'class', 'method', 'script', 'module' + docstring: Full docstring text + line_number: Starting line in source file (optional) + """ + # Limit docstring size to prevent bloat + if len(docstring) > 10000: + docstring = docstring[:10000] + + conn.execute( + """ + INSERT INTO docstrings (repo, file_path, language, symbol_name, + symbol_type, docstring, line_number, synced_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(repo, file_path, symbol_name) DO UPDATE SET + docstring=excluded.docstring, + symbol_type=excluded.symbol_type, + line_number=excluded.line_number, + synced_at=excluded.synced_at + """, + (repo, file_path, language, symbol_name, symbol_type, docstring, line_number, _now_iso()), + ) + + def get_last_sync(source_type: str, source_name: str, project: str = "hed") -> str | None: """Get last sync time for a source. @@ -358,4 +442,13 @@ def get_stats(project: str = "hed") -> dict[str, int]: "SELECT COUNT(*) FROM papers WHERE source='pubmed'" ).fetchone()[0] + # Docstring stats + stats["docstrings_total"] = conn.execute("SELECT COUNT(*) FROM docstrings").fetchone()[0] + stats["docstrings_matlab"] = conn.execute( + "SELECT COUNT(*) FROM docstrings WHERE language='matlab'" + ).fetchone()[0] + stats["docstrings_python"] = conn.execute( + "SELECT COUNT(*) FROM docstrings WHERE language='python'" + ).fetchone()[0] + return stats diff --git a/src/knowledge/docstring_sync.py b/src/knowledge/docstring_sync.py new file mode 100644 index 0000000..290fa8b --- /dev/null +++ b/src/knowledge/docstring_sync.py @@ -0,0 +1,224 @@ +"""Docstring sync from GitHub repositories. + +Fetches source files from GitHub and extracts docstrings for indexing. +Supports MATLAB (.m) and Python (.py) files. +""" + +import logging +from typing import Literal + +import httpx +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn + +from src.api.config import get_settings +from src.knowledge.db import get_connection, update_sync_metadata, upsert_docstring +from src.knowledge.matlab_parser import parse_matlab_file +from src.knowledge.python_parser import parse_python_file + +logger = logging.getLogger(__name__) +console = Console() + +GITHUB_API_BASE = "https://api.github.com" +GITHUB_RAW_BASE = "https://raw.githubusercontent.com" + + +def sync_repo_docstrings( + repo: str, + language: Literal["matlab", "python"], + project: str = "hed", + branch: str = "main", +) -> int: + """Sync docstrings from a GitHub repository. + + Args: + repo: Repository in owner/name format (e.g., 'sccn/eeglab') + language: 'matlab' or 'python' + project: Community ID for database isolation + branch: Git branch to sync + + Returns: + Number of docstrings extracted + + Raises: + httpx.HTTPStatusError: If GitHub API requests fail + """ + console.print(f"Syncing {language} docstrings from {repo} ({branch})...") + + # Determine file extension + extension = ".m" if language == "matlab" else ".py" + + # Get list of files from GitHub + files = _get_repo_files(repo, branch, extension) + console.print(f"Found {len(files)} {extension} files") + + if not files: + console.print(f"[yellow]No {extension} files found in {repo}[/yellow]") + return 0 + + # Process files and extract docstrings + total_docstrings = 0 + failed_files: list[tuple[str, str]] = [] + uncommitted = 0 + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("Processing files...", total=len(files)) + + with get_connection(project) as conn: + for file_path in files: + try: + # Fetch file content + content = _fetch_file_content(repo, branch, file_path) + + # Parse docstrings + if language == "matlab": + docstrings = parse_matlab_file(content, file_path) + else: + docstrings = parse_python_file(content, file_path) + + # Insert into database + for doc in docstrings: + upsert_docstring( + conn, + repo=repo, + file_path=file_path, + language=language, + symbol_name=doc.symbol_name, + symbol_type=doc.symbol_type, + docstring=doc.docstring, + line_number=doc.line_number, + ) + total_docstrings += 1 + uncommitted += 1 + + # Commit every 50 docstrings to avoid large transactions + if uncommitted >= 50: + conn.commit() + uncommitted = 0 + + except httpx.HTTPStatusError as e: + error_msg = f"HTTP {e.response.status_code}" + logger.error("HTTP error fetching %s: %s", file_path, e) + failed_files.append((file_path, error_msg)) + except httpx.TimeoutException: + logger.error("Timeout fetching %s", file_path) + failed_files.append((file_path, "Timeout")) + except SyntaxError as e: + logger.error("Syntax error in %s: %s", file_path, e) + failed_files.append((file_path, f"Syntax error: {e}")) + except UnicodeDecodeError as e: + logger.error("Encoding error in %s: %s", file_path, e) + failed_files.append((file_path, "Invalid encoding")) + except Exception as e: + logger.error("Unexpected error processing %s: %s", file_path, e, exc_info=True) + failed_files.append((file_path, f"Error: {type(e).__name__}")) + + progress.update(task, advance=1) + + # Final commit + conn.commit() + + # Update sync metadata + update_sync_metadata("docstrings", f"{repo}:{language}", total_docstrings, project) + + # Report results + console.print(f"[green]✓ Extracted {total_docstrings} docstrings[/green]") + + if failed_files: + console.print(f"\n[yellow]Warning: Failed to process {len(failed_files)} files:[/yellow]") + for path, error in failed_files[:10]: # Show first 10 + console.print(f" ✗ {path}: {error}") + if len(failed_files) > 10: + console.print(f" ... and {len(failed_files) - 10} more") + + return total_docstrings + + +def _get_repo_files(repo: str, branch: str, extension: str) -> list[str]: + """Get list of files with given extension from repository. + + Uses GitHub API with optional authentication for higher rate limits. + + Args: + repo: Repository in owner/name format + branch: Git branch + extension: File extension (e.g., '.py' or '.m') + + Returns: + List of file paths relative to repo root + + Raises: + httpx.HTTPStatusError: If API request fails + ValueError: If response format is unexpected + """ + settings = get_settings() + headers = { + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + } + + # Optional token for higher rate limits (60 req/hr -> 5000 req/hr) + if settings.github_token: + headers["Authorization"] = f"Bearer {settings.github_token}" + logger.debug("Using GitHub token for authentication") + + url = f"{GITHUB_API_BASE}/repos/{repo}/git/trees/{branch}?recursive=1" + + try: + response = httpx.get(url, headers=headers, timeout=30, follow_redirects=True) + response.raise_for_status() + except httpx.TimeoutException as e: + logger.error("Timeout fetching file tree from %s", repo) + raise TimeoutError( + f"GitHub request timed out after 30 seconds. Repo: {repo}, branch: {branch}" + ) from e + + try: + tree = response.json() + except ValueError as e: + logger.error("Invalid JSON from GitHub API for %s: %s", repo, e) + raise ValueError(f"GitHub returned invalid response for {repo}") from e + + if "tree" not in tree: + logger.error("Unexpected GitHub response format for %s: missing 'tree' key", repo) + raise ValueError(f"Unexpected response format from GitHub for {repo}") + + # Filter for files with the target extension + files = [ + item["path"] + for item in tree.get("tree", []) + if item.get("type") == "blob" and item["path"].endswith(extension) + ] + + return files + + +def _fetch_file_content(repo: str, branch: str, file_path: str) -> str: + """Fetch raw file content from GitHub. + + Args: + repo: Repository in owner/name format + branch: Git branch + file_path: File path relative to repo root + + Returns: + File content as string + + Raises: + httpx.HTTPStatusError: If request fails + TimeoutError: If request times out after 30 seconds + """ + url = f"{GITHUB_RAW_BASE}/{repo}/{branch}/{file_path}" + + try: + response = httpx.get(url, timeout=30, follow_redirects=True) + response.raise_for_status() + except httpx.TimeoutException as e: + logger.error("Timeout fetching %s from %s", file_path, repo) + raise TimeoutError(f"GitHub request timed out after 30 seconds. File: {file_path}") from e + + return response.text diff --git a/src/knowledge/matlab_parser.py b/src/knowledge/matlab_parser.py new file mode 100644 index 0000000..fd2e77b --- /dev/null +++ b/src/knowledge/matlab_parser.py @@ -0,0 +1,125 @@ +"""MATLAB docstring parser using regex. + +Extracts docstrings from MATLAB files for indexing and search. +Supports: functions (including nested) and scripts with header comments. + +MATLAB documentation conventions: +- Comments start with % +- Function help appears in comment block before the function definition +- Script help appears at the top of the file +""" + +import logging +import re +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +@dataclass +class MatlabDocstring: + """Parsed MATLAB docstring with metadata.""" + + symbol_name: str + symbol_type: str # 'function' or 'script' + docstring: str + line_number: int + + +def parse_matlab_file(content: str, file_path: str) -> list[MatlabDocstring]: + """Parse MATLAB file and extract all docstrings. + + Args: + content: File content as string + file_path: Path to the file (for module name extraction) + + Returns: + List of extracted docstrings with metadata + """ + results: list[MatlabDocstring] = [] + lines = content.split("\n") + + # Pattern to match function definitions + # Matches: + # function [out1, out2] = name(in1, in2) - multiple outputs + # function out = name(in1, in2) - single output + # function name(in1, in2) - no outputs + func_pattern = re.compile(r"^\s*function\s+(?:(?:\[[\w,\s]*\]|\w+)\s*=\s*)?(\w+)\s*\(") + + # Look for function definitions and their preceding comment blocks + for i, line in enumerate(lines): + match = func_pattern.match(line) + if match: + func_name = match.group(1) + + # Look backward for comment block (stop at first non-comment line) + comments = [] + j = i - 1 + while j >= 0: + stripped = lines[j].strip() + if stripped.startswith("%"): + # Remove comment marker and optional space + comment = re.sub(r"^\s*%+\s?", "", lines[j]) + comments.insert(0, comment) + j -= 1 + elif not stripped: + # Allow empty lines in comment block + j -= 1 + else: + # Hit non-comment, non-empty line + break + + if comments: + # Found docstring for this function + docstring = "\n".join(comments).strip() + # Calculate the line where the comment block starts + comment_start_line = i - len(comments) + 1 + results.append( + MatlabDocstring( + symbol_name=func_name, + symbol_type="function", + docstring=docstring, + line_number=comment_start_line, + ) + ) + + # If no functions found, check for script header comments + if not results: + script_comments = [] + for _i, line in enumerate(lines): + stripped = line.strip() + if stripped.startswith("%"): + comment = re.sub(r"^\s*%+\s?", "", line) + script_comments.append(comment) + elif stripped: + # Hit first non-comment, non-empty line + break + + if script_comments: + # This is a script with header documentation + script_name = _get_module_name(file_path) + docstring = "\n".join(script_comments).strip() + results.append( + MatlabDocstring( + symbol_name=script_name, + symbol_type="script", + docstring=docstring, + line_number=1, + ) + ) + + return results + + +def _get_module_name(file_path: str) -> str: + """Extract module name from file path. + + Args: + file_path: Path like 'functions/popfunc/pop_loadset.m' + + Returns: + Module name like 'pop_loadset' (without extension) + """ + import os + + return os.path.splitext(os.path.basename(file_path))[0] diff --git a/src/knowledge/python_parser.py b/src/knowledge/python_parser.py new file mode 100644 index 0000000..1f45860 --- /dev/null +++ b/src/knowledge/python_parser.py @@ -0,0 +1,104 @@ +"""Python docstring parser using AST. + +Extracts docstrings from Python files for indexing and search. +Supports: modules, functions, classes, and methods. +""" + +import ast +import logging +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +@dataclass +class PythonDocstring: + """Parsed Python docstring with metadata.""" + + symbol_name: str + symbol_type: str # 'function', 'class', 'method', 'module' + docstring: str + line_number: int + + +def parse_python_file(content: str, file_path: str) -> list[PythonDocstring]: + """Parse Python file and extract all docstrings. + + Args: + content: File content as string + file_path: Path to the file (for error reporting) + + Returns: + List of extracted docstrings with metadata + + Raises: + SyntaxError: If the Python file has syntax errors + """ + results: list[PythonDocstring] = [] + + # Let SyntaxError propagate to caller for proper error handling + tree = ast.parse(content) + + # Build parent map once for efficient method detection + parents: dict[ast.AST, ast.AST] = {} + for node in ast.walk(tree): + for child in ast.iter_child_nodes(node): + parents[child] = node + + # Module-level docstring + module_doc = ast.get_docstring(tree) + if module_doc: + results.append( + PythonDocstring( + symbol_name=_get_module_name(file_path), + symbol_type="module", + docstring=module_doc, + line_number=1, + ) + ) + + # Walk the AST and extract docstrings from functions and classes + for node in ast.walk(tree): + # Only process nodes that can have docstrings + if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef): + docstring = ast.get_docstring(node) + if docstring: + # Determine if it's a method or function using parent map + parent = parents.get(node) + symbol_type = "method" if isinstance(parent, ast.ClassDef) else "function" + results.append( + PythonDocstring( + symbol_name=node.name, + symbol_type=symbol_type, + docstring=docstring, + line_number=node.lineno, + ) + ) + + elif isinstance(node, ast.ClassDef): + docstring = ast.get_docstring(node) + if docstring: + results.append( + PythonDocstring( + symbol_name=node.name, + symbol_type="class", + docstring=docstring, + line_number=node.lineno, + ) + ) + + return results + + +def _get_module_name(file_path: str) -> str: + """Extract module name from file path. + + Args: + file_path: Path like 'mne/io/fiff/raw.py' + + Returns: + Module name like 'raw' (without extension) + """ + import os + + return os.path.splitext(os.path.basename(file_path))[0] diff --git a/src/knowledge/search.py b/src/knowledge/search.py index 6432320..cf7fa40 100644 --- a/src/knowledge/search.py +++ b/src/knowledge/search.py @@ -114,10 +114,10 @@ def search_github_items( status: str | None = None, repo: str | None = None, ) -> list[SearchResult]: - """Search GitHub issues and PRs using FTS5. + """Search GitHub issues and PRs using phrase matching. Args: - query: Search query (FTS5 syntax supported, e.g., "validation AND error") + query: Search phrase (treated as exact phrase, not FTS5 operators) project: Assistant/project name for database isolation. Defaults to 'hed'. limit: Maximum number of results item_type: Filter by 'issue' or 'pr' @@ -198,10 +198,10 @@ def search_papers( limit: int = 10, source: str | None = None, ) -> list[SearchResult]: - """Search papers using FTS5. + """Search papers using phrase matching. Args: - query: Search query (FTS5 syntax supported) + query: Search phrase (treated as exact phrase, not FTS5 operators) project: Assistant/project name for database isolation. Defaults to 'hed'. limit: Maximum number of results source: Filter by source ('openalex', 'semanticscholar', 'pubmed') @@ -384,3 +384,99 @@ def list_recent_github_items( raise return results + + +def search_docstrings( + query: str, + project: str = "hed", + limit: int = 10, + language: str | None = None, + repo: str | None = None, +) -> list[SearchResult]: + """Search code docstrings using phrase matching. + + Args: + query: Search phrase (treated as exact phrase, not FTS5 operators) + project: Assistant/project name for database isolation. Defaults to 'hed'. + limit: Maximum number of results + language: Filter by 'matlab' or 'python' + repo: Filter by repository name + + Returns: + List of matching results with GitHub source links, ordered by relevance + """ + sql = """ + SELECT d.symbol_name, d.docstring, d.file_path, d.repo, + d.language, d.symbol_type, d.line_number + FROM docstrings_fts f + JOIN docstrings d ON f.rowid = d.id + WHERE docstrings_fts MATCH ? + """ + params: list[str | int] = [query] + + if language: + sql += " AND d.language = ?" + params.append(language) + if repo: + sql += " AND d.repo = ?" + params.append(repo) + + sql += " ORDER BY rank LIMIT ?" + params.append(limit) + + results = [] + try: + with get_connection(project) as conn: + # Sanitize user query to prevent FTS5 injection + safe_query = _sanitize_fts5_query(query) + params[0] = safe_query + + for row in conn.execute(sql, params): + # Create snippet from docstring (first 200 chars) + docstring = row["docstring"] or "" + snippet = docstring[:200].strip() + if len(docstring) > 200: + snippet += "..." + + # Build GitHub URL to the specific line + file_path = row["file_path"] + repo_name = row["repo"] + line_number = row["line_number"] + # LIMITATION: Hardcoded to 'main' branch - links will break + # for repos using 'develop', 'master', or other default branches. + # TODO: Store branch name during sync and use it here. + github_url = f"https://github.com/{repo_name}/blob/main/{file_path}" + if line_number: + github_url += f"#L{line_number}" + + # Format title as "symbol_name (type) - file_path" + symbol_name = row["symbol_name"] + symbol_type = row["symbol_type"] + title = f"{symbol_name} ({symbol_type}) - {file_path}" + + results.append( + SearchResult( + title=title, + url=github_url, + snippet=snippet, + source=row["language"], + item_type=symbol_type, + status="documented", + created_at="", + ) + ) + except sqlite3.OperationalError as e: + # Infrastructure failure (corruption, disk full, permissions) - must propagate + logger.error( + "Database operational error during docstring search: %s", + e, + exc_info=True, + extra={"query": query, "project": project}, + ) + raise # Let API layer return 500, not empty results + except sqlite3.Error as e: + # Other database errors - still raise for debugging + logger.warning("Database error during docstring search '%s': %s", query, e) + raise + + return results diff --git a/src/tools/knowledge.py b/src/tools/knowledge.py index 5fb6d07..8c952cb 100644 --- a/src/tools/knowledge.py +++ b/src/tools/knowledge.py @@ -23,6 +23,7 @@ from src.knowledge.db import get_db_path from src.knowledge.search import ( list_recent_github_items, + search_docstrings, search_github_items, search_papers, ) @@ -252,6 +253,65 @@ def search_papers_impl(query: str, limit: int = 5) -> str: ) +def create_search_docstrings_tool( + community_id: str, + community_name: str, + language: str | None = None, +) -> BaseTool: + """Create a tool for searching code docstrings for a community. + + Args: + community_id: The community identifier (e.g., 'hed', 'bids', 'eeglab') + community_name: Display name (e.g., 'HED', 'BIDS', 'EEGLAB') + language: Optional language filter ('matlab' or 'python') + + Returns: + A LangChain tool for searching code documentation + """ + lang_help = "" + if language: + lang_help = f" Only searches {language.upper()} code." + else: + lang_help = " Searches both MATLAB and Python code." + + def search_docstrings_impl(query: str, limit: int = 5) -> str: + """Search code docstrings implementation.""" + if not _check_db_exists(community_id): + return ( + f"Knowledge database for {community_name} not initialized. " + "Run 'osa sync init' and 'osa sync docstrings' to populate it." + ) + + results = search_docstrings(query, project=community_id, limit=limit, language=language) + + if not results: + lang_str = f" ({language})" if language else "" + return f"No code documentation found for '{query}'{lang_str}." + + lines = [f"Code documentation in {community_name}:\n"] + for r in results: + lines.append(f"- {r.title}") + lines.append(f" [View source on GitHub]({r.url})") + if r.snippet: + snippet = r.snippet[:200] + "..." if len(r.snippet) > 200 else r.snippet + lines.append(f" Documentation: {snippet}") + lines.append("") + + return "\n".join(lines) + + description = ( + f"Search {community_name} code documentation (docstrings from functions, classes, scripts).{lang_help} " + "Use this to find how specific functions work, what parameters they accept, " + "and see usage examples. Results include direct links to source code on GitHub." + ) + + return StructuredTool.from_function( + func=search_docstrings_impl, + name=f"search_{community_id}_code_docs", + description=description, + ) + + def create_knowledge_tools( community_id: str, community_name: str, @@ -259,6 +319,8 @@ def create_knowledge_tools( include_discussions: bool = True, include_recent: bool = True, include_papers: bool = True, + include_docstrings: bool = False, + docstrings_language: str | None = None, ) -> list[BaseTool]: """Create all knowledge discovery tools for a community. @@ -266,12 +328,14 @@ def create_knowledge_tools( based on the community configuration. Args: - community_id: The community identifier (e.g., 'hed', 'bids') - community_name: Display name (e.g., 'HED', 'BIDS') + community_id: The community identifier (e.g., 'hed', 'bids', 'eeglab') + community_name: Display name (e.g., 'HED', 'BIDS', 'EEGLAB') repos: Optional list of GitHub repos for help text include_discussions: Include discussion search tool (default: True) include_recent: Include recent activity tool (default: True) include_papers: Include paper search tool (default: True) + include_docstrings: Include code docstring search tool (default: False) + docstrings_language: Filter docstrings by language ('matlab' or 'python') Returns: List of LangChain tools for the community @@ -287,4 +351,9 @@ def create_knowledge_tools( if include_papers: tools.append(create_search_papers_tool(community_id, community_name)) + if include_docstrings: + tools.append( + create_search_docstrings_tool(community_id, community_name, docstrings_language) + ) + return tools diff --git a/src/version.py b/src/version.py index b399e9a..c5f7bb7 100644 --- a/src/version.py +++ b/src/version.py @@ -1,7 +1,7 @@ """Version information for OSA.""" -__version__ = "0.5.2" -__version_info__ = (0, 5, 2) +__version__ = "0.5.3.dev0" +__version_info__ = (0, 5, 3, "dev") def get_version() -> str: diff --git a/tests/test_core/test_litellm_llm.py b/tests/test_core/test_litellm_llm.py new file mode 100644 index 0000000..0732fd6 --- /dev/null +++ b/tests/test_core/test_litellm_llm.py @@ -0,0 +1,469 @@ +"""Tests for LiteLLM integration and caching functionality. + +Testing Approach: +----------------- +This test suite uses a two-tier testing strategy to balance the NO MOCKS policy +with practical test requirements: + +1. **Unit Tests (FakeListChatModel)**: Test CachingLLMWrapper's internal mechanics + - Message transformation logic (_add_cache_control) + - Input validation and error handling + - Wrapper initialization and double-wrapping prevention + + These tests use FakeListChatModel (a LangChain test utility) because they're + testing the WRAPPER's logic, not LLM behavior. The wrapper mechanics are + deterministic and don't require real API calls. + +2. **Integration Tests (Real API)**: Test actual LLM behavior with caching + - Real Anthropic API calls with @pytest.mark.llm marker + - Tool binding with actual models + - Streaming responses with cache_control + + These tests verify the wrapper works correctly with real LLMs and that + cache_control parameters are properly transmitted. + +This separation allows fast unit tests for wrapper logic while ensuring real +LLM integration is thoroughly tested. The FakeListChatModel is NOT used to +test LLM responses or behavior - only wrapper mechanics. +""" + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from langchain_core.tools import tool + +from src.core.services.litellm_llm import CachingLLMWrapper, create_openrouter_llm + + +# Test tool for tool binding tests +@tool +def calculator(expression: str) -> str: + """Calculate a mathematical expression. + + Args: + expression: A mathematical expression to evaluate (basic arithmetic only) + + Returns: + The result of the calculation + """ + import ast + import operator + + # Safe operators mapping + safe_operators = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.Mod: operator.mod, + ast.Pow: operator.pow, + ast.UAdd: operator.pos, + ast.USub: operator.neg, + } + + def safe_eval(node): + """Safely evaluate an AST node with only basic arithmetic.""" + if isinstance(node, ast.Constant): + return node.value + elif isinstance(node, ast.BinOp): + left = safe_eval(node.left) + right = safe_eval(node.right) + op = safe_operators.get(type(node.op)) + if op is None: + raise ValueError(f"Unsupported operator: {type(node.op).__name__}") + return op(left, right) + elif isinstance(node, ast.UnaryOp): + operand = safe_eval(node.operand) + op = safe_operators.get(type(node.op)) + if op is None: + raise ValueError(f"Unsupported operator: {type(node.op).__name__}") + return op(operand) + else: + raise ValueError(f"Unsupported expression: {type(node).__name__}") + + try: + # Parse the expression into an AST + tree = ast.parse(expression, mode="eval") + # Evaluate using only safe operations + result = safe_eval(tree.body) + return str(result) + except Exception as e: + return f"Error: {str(e)}" + + +class TestCachingLLMWrapperInitialization: + """Test CachingLLMWrapper initialization and validation.""" + + def test_prevents_double_wrapping(self): + """Verify double-wrapping is prevented with clear error.""" + from langchain_community.chat_models import FakeListChatModel + + fake_llm = FakeListChatModel(responses=["Test"]) + wrapper = CachingLLMWrapper(llm=fake_llm) + + # Attempt to wrap a wrapper should raise ValueError + with pytest.raises(ValueError) as exc_info: + CachingLLMWrapper(llm=wrapper) + + assert "Cannot wrap a CachingLLMWrapper" in str(exc_info.value) + assert "infinite recursion" in str(exc_info.value) + + def test_validates_llm_has_invoke_method(self): + """Verify initialization requires invoke method.""" + + # Create a mock object without invoke method + class InvalidLLM: + pass + + with pytest.raises(TypeError) as exc_info: + CachingLLMWrapper(llm=InvalidLLM()) + + assert "missing required 'invoke' method" in str(exc_info.value) + + +class TestCachingLLMWrapperMessageTransformation: + """Test message transformation with cache_control markers.""" + + def test_add_cache_control_transforms_system_message(self): + """Verify system message gets cache_control marker.""" + from langchain_community.chat_models import FakeListChatModel + + fake_llm = FakeListChatModel(responses=["Test response"]) + wrapper = CachingLLMWrapper(llm=fake_llm) + + messages = [ + SystemMessage(content="You are a helpful assistant."), + HumanMessage(content="Hello"), + ] + + result = wrapper._add_cache_control(messages) + + # Check system message structure + assert result[0]["role"] == "system" + assert isinstance(result[0]["content"], list) + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][0]["text"] == "You are a helpful assistant." + assert result[0]["content"][0]["cache_control"] == {"type": "ephemeral"} + + # Check human message unchanged + assert result[1]["role"] == "user" + assert result[1]["content"] == "Hello" + + def test_add_cache_control_preserves_other_messages(self): + """Verify non-system messages unchanged.""" + from langchain_community.chat_models import FakeListChatModel + + fake_llm = FakeListChatModel(responses=["Test"]) + wrapper = CachingLLMWrapper(llm=fake_llm) + + messages = [ + HumanMessage(content="User query"), + AIMessage(content="Assistant response"), + ] + + result = wrapper._add_cache_control(messages) + + assert result[0]["role"] == "user" + assert result[0]["content"] == "User query" + assert result[1]["role"] == "assistant" + assert result[1]["content"] == "Assistant response" + # No cache_control on non-system messages + assert "cache_control" not in str(result[0]) + assert "cache_control" not in str(result[1]) + + def test_add_cache_control_handles_empty_list(self): + """Verify empty message list handled gracefully.""" + from langchain_community.chat_models import FakeListChatModel + + fake_llm = FakeListChatModel(responses=["Test"]) + wrapper = CachingLLMWrapper(llm=fake_llm) + + result = wrapper._add_cache_control([]) + + assert result == [] + + def test_add_cache_control_handles_multiple_system_messages(self): + """Verify multiple system messages all get cache_control.""" + from langchain_community.chat_models import FakeListChatModel + + fake_llm = FakeListChatModel(responses=["Test"]) + wrapper = CachingLLMWrapper(llm=fake_llm) + + messages = [ + SystemMessage(content="System prompt 1"), + SystemMessage(content="System prompt 2"), + HumanMessage(content="Query"), + ] + + result = wrapper._add_cache_control(messages) + + # Both system messages should have cache_control + assert result[0]["content"][0]["cache_control"] == {"type": "ephemeral"} + assert result[1]["content"][0]["cache_control"] == {"type": "ephemeral"} + + def test_add_cache_control_rejects_none_input(self): + """Verify None input raises ValueError.""" + from langchain_community.chat_models import FakeListChatModel + + fake_llm = FakeListChatModel(responses=["Test"]) + wrapper = CachingLLMWrapper(llm=fake_llm) + + with pytest.raises(ValueError) as exc_info: + wrapper._add_cache_control(None) + + assert "cannot be None" in str(exc_info.value) + + def test_add_cache_control_rejects_non_list_input(self): + """Verify non-list input raises TypeError.""" + from langchain_community.chat_models import FakeListChatModel + + fake_llm = FakeListChatModel(responses=["Test"]) + wrapper = CachingLLMWrapper(llm=fake_llm) + + with pytest.raises(TypeError) as exc_info: + wrapper._add_cache_control("not a list") + + assert "Expected list of messages" in str(exc_info.value) + + +class TestCachingLLMWrapperToolBinding: + """Test tool binding preserves caching. + + Note: Tool binding tests use real models in integration tests below, + since bind_tools() behavior is model-specific and hard to test in isolation. + """ + + def test_bind_tools_rejects_empty_list(self): + """Verify empty tools list raises ValueError.""" + from langchain_community.chat_models import FakeListChatModel + + fake_llm = FakeListChatModel(responses=["Test"]) + wrapper = CachingLLMWrapper(llm=fake_llm) + + with pytest.raises(ValueError) as exc_info: + wrapper.bind_tools([]) + + assert "empty tools list" in str(exc_info.value) + + def test_bind_tools_checks_for_method_support(self): + """Verify error when LLM doesn't support bind_tools.""" + from langchain_community.chat_models import FakeListChatModel + + # Check if FakeListChatModel has bind_tools + fake_llm = FakeListChatModel(responses=["Test"]) + if not hasattr(fake_llm, "bind_tools"): + # If it doesn't have bind_tools, test the error path + wrapper = CachingLLMWrapper(llm=fake_llm) + + with pytest.raises(NotImplementedError) as exc_info: + wrapper.bind_tools([calculator]) + + assert "does not support tool binding" in str(exc_info.value) + else: + # If it does have bind_tools, skip this test + pytest.skip("FakeListChatModel has bind_tools, cannot test missing method error") + + @pytest.mark.llm + def test_bind_tools_returns_caching_wrapper(self): + """Verify bind_tools() returns CachingLLMWrapper instance with real model.""" + import os + + api_key = os.getenv("OPENROUTER_API_KEY_FOR_TESTING") + if not api_key: + pytest.skip("OPENROUTER_API_KEY_FOR_TESTING not set") + + llm = create_openrouter_llm( + model="anthropic/claude-haiku-4.5", + api_key=api_key, + provider="Anthropic", + enable_caching=True, + ) + + # Verify initial wrapper + assert isinstance(llm, CachingLLMWrapper) + + # Bind tools + bound_model = llm.bind_tools([calculator]) + + # Verify result is still wrapped + assert isinstance(bound_model, CachingLLMWrapper) + + @pytest.mark.llm + def test_nested_bind_tools(self): + """Verify multiple bind_tools() calls work correctly with real model.""" + import os + + api_key = os.getenv("OPENROUTER_API_KEY_FOR_TESTING") + if not api_key: + pytest.skip("OPENROUTER_API_KEY_FOR_TESTING not set") + + llm = create_openrouter_llm( + model="anthropic/claude-haiku-4.5", + api_key=api_key, + provider="Anthropic", + enable_caching=True, + ) + + # First binding + bound_once = llm.bind_tools([calculator]) + assert isinstance(bound_once, CachingLLMWrapper) + + # Second binding should also work + bound_twice = bound_once.bind_tools([calculator]) + assert isinstance(bound_twice, CachingLLMWrapper) + + +# Invocation and streaming tests are covered by integration tests below +# since they require real model behavior that's hard to test in isolation + + +class TestCachingLLMWrapperIntegration: + """Integration tests with real API calls (requires OPENROUTER_API_KEY_FOR_TESTING).""" + + @pytest.mark.llm + def test_caching_wrapper_with_anthropic_model(self): + """End-to-end test with real Anthropic model.""" + import os + + api_key = os.getenv("OPENROUTER_API_KEY_FOR_TESTING") + if not api_key: + pytest.skip("OPENROUTER_API_KEY_FOR_TESTING not set") + + llm = create_openrouter_llm( + model="anthropic/claude-haiku-4.5", + api_key=api_key, + provider="Anthropic", + enable_caching=True, + ) + + messages = [ + SystemMessage(content="You are a helpful assistant. Always respond concisely."), + HumanMessage(content="Say 'Hello' and nothing else."), + ] + + response = llm.invoke(messages) + + # Verify response received + assert response is not None + assert hasattr(response, "content") + assert "hello" in response.content.lower() + + @pytest.mark.llm + def test_tool_binding_with_anthropic_model(self): + """End-to-end test with tools and Anthropic model.""" + import os + + api_key = os.getenv("OPENROUTER_API_KEY_FOR_TESTING") + if not api_key: + pytest.skip("OPENROUTER_API_KEY_FOR_TESTING not set") + + llm = create_openrouter_llm( + model="anthropic/claude-haiku-4.5", + api_key=api_key, + provider="Anthropic", + enable_caching=True, + ) + + bound_model = llm.bind_tools([calculator]) + + messages = [ + SystemMessage(content="You are a helpful calculator assistant."), + HumanMessage(content="What is 25 * 4?"), + ] + + response = bound_model.invoke(messages) + + # Verify response received (tool may or may not be called depending on model) + assert response is not None + assert hasattr(response, "content") + + @pytest.mark.llm + def test_streaming_with_caching(self): + """Verify streaming works with caching.""" + import os + + api_key = os.getenv("OPENROUTER_API_KEY_FOR_TESTING") + if not api_key: + pytest.skip("OPENROUTER_API_KEY_FOR_TESTING not set") + + llm = create_openrouter_llm( + model="anthropic/claude-haiku-4.5", + api_key=api_key, + provider="Anthropic", + enable_caching=True, + ) + + messages = [ + SystemMessage(content="You are a helpful assistant."), + HumanMessage(content="Count from 1 to 3."), + ] + + chunks = [] + for chunk in llm.stream(messages): + chunks.append(chunk) + + # Verify chunks received + assert len(chunks) > 0 + + # Assemble full response + full_response = "".join(str(chunk.content) for chunk in chunks if hasattr(chunk, "content")) + assert len(full_response) > 0 + + @pytest.mark.llm + async def test_async_invoke_with_caching(self): + """Verify async invoke works with caching.""" + import os + + api_key = os.getenv("OPENROUTER_API_KEY_FOR_TESTING") + if not api_key: + pytest.skip("OPENROUTER_API_KEY_FOR_TESTING not set") + + llm = create_openrouter_llm( + model="anthropic/claude-haiku-4.5", + api_key=api_key, + provider="Anthropic", + enable_caching=True, + ) + + messages = [ + SystemMessage(content="You are a helpful assistant. Always respond concisely."), + HumanMessage(content="Say 'Hello' and nothing else."), + ] + + response = await llm.ainvoke(messages) + + # Verify response received + assert response is not None + assert hasattr(response, "content") + assert "hello" in response.content.lower() + + @pytest.mark.llm + async def test_async_streaming_with_caching(self): + """Verify async streaming works with caching.""" + import os + + api_key = os.getenv("OPENROUTER_API_KEY_FOR_TESTING") + if not api_key: + pytest.skip("OPENROUTER_API_KEY_FOR_TESTING not set") + + llm = create_openrouter_llm( + model="anthropic/claude-haiku-4.5", + api_key=api_key, + provider="Anthropic", + enable_caching=True, + ) + + messages = [ + SystemMessage(content="You are a helpful assistant."), + HumanMessage(content="Count from 1 to 3."), + ] + + chunks = [] + async for chunk in llm.astream(messages): + chunks.append(chunk) + + # Verify chunks received + assert len(chunks) > 0 + + # Assemble full response + full_response = "".join(str(chunk.content) for chunk in chunks if hasattr(chunk, "content")) + assert len(full_response) > 0 diff --git a/tests/test_integration/test_docstring_workflow.py b/tests/test_integration/test_docstring_workflow.py new file mode 100644 index 0000000..449318f --- /dev/null +++ b/tests/test_integration/test_docstring_workflow.py @@ -0,0 +1,222 @@ +"""Integration tests for complete docstring workflow.""" + +import pytest + +from src.knowledge.db import get_db_path, get_stats, init_db +from src.knowledge.search import search_docstrings +from src.tools.knowledge import create_search_docstrings_tool + + +@pytest.fixture +def test_project(): + """Provide a test project name.""" + return "test-integration-docstrings" + + +@pytest.fixture +def clean_db(test_project): + """Ensure clean database for each test.""" + db_path = get_db_path(test_project) + if db_path.exists(): + db_path.unlink() + init_db(test_project) + yield test_project + # Cleanup after test + if db_path.exists(): + db_path.unlink() + + +def test_search_empty_database(clean_db): + """Test searching with empty database returns no results.""" + results = search_docstrings("test query", project=clean_db, limit=5) + assert len(results) == 0 + + +def test_search_with_data(clean_db): + """Test searching docstrings after inserting data.""" + from src.knowledge.db import get_connection, upsert_docstring + + # Insert test docstrings + with get_connection(clean_db) as conn: + upsert_docstring( + conn, + repo="test/repo", + file_path="test_func.m", + language="matlab", + symbol_name="test_function", + symbol_type="function", + docstring="This function tests the loadset functionality for EEG data", + line_number=1, + ) + upsert_docstring( + conn, + repo="test/repo", + file_path="helper.py", + language="python", + symbol_name="process_data", + symbol_type="function", + docstring="Process raw data and return cleaned results", + line_number=10, + ) + conn.commit() + + # Search for matlab function (simple single-word query) + results = search_docstrings("loadset", project=clean_db, limit=5, language="matlab") + assert len(results) == 1, f"Expected 1 result, got {len(results)}" + assert results[0].title == "test_function (function) - test_func.m" + assert "loadset" in results[0].snippet.lower() + + # Search for python function (simple query) + results = search_docstrings("process", project=clean_db, limit=5, language="python") + assert len(results) == 1 + assert "process_data" in results[0].title + + # Search without language filter (should find both) + results = search_docstrings("data", project=clean_db, limit=5) + assert len(results) == 2 + + +def test_tool_with_empty_db(clean_db): + """Test tool returns helpful message when database is empty.""" + tool = create_search_docstrings_tool(clean_db, "Test Community", language="matlab") + + # Search in empty database + result = tool.invoke({"query": "test", "limit": 5}) + + assert isinstance(result, str) + assert "No code documentation found" in result + + +def test_tool_with_data(clean_db): + """Test tool returns formatted results.""" + from src.knowledge.db import get_connection, upsert_docstring + + # Insert test docstring + with get_connection(clean_db) as conn: + upsert_docstring( + conn, + repo="sccn/eeglab", + file_path="functions/popfunc/pop_loadset.m", + language="matlab", + symbol_name="pop_loadset", + symbol_type="function", + docstring="pop_loadset() - load an EEG dataset", + line_number=5, + ) + conn.commit() + + # Create and invoke tool + tool = create_search_docstrings_tool(clean_db, "EEGLAB", language="matlab") + result = tool.invoke({"query": "loadset", "limit": 5}) + + # Verify formatted output + assert isinstance(result, str) + assert "pop_loadset" in result + assert "View source on GitHub" in result + assert "github.com" in result + + +def test_database_stats_after_insert(clean_db): + """Test that database stats reflect inserted docstrings.""" + from src.knowledge.db import get_connection, upsert_docstring + + # Initially empty + stats = get_stats(clean_db) + assert stats["docstrings_total"] == 0 + assert stats["docstrings_matlab"] == 0 + assert stats["docstrings_python"] == 0 + + # Insert docstrings + with get_connection(clean_db) as conn: + upsert_docstring( + conn, + repo="test/repo", + file_path="test.m", + language="matlab", + symbol_name="test1", + symbol_type="function", + docstring="Test docstring", + line_number=1, + ) + upsert_docstring( + conn, + repo="test/repo", + file_path="test.py", + language="python", + symbol_name="test2", + symbol_type="function", + docstring="Test docstring", + line_number=1, + ) + conn.commit() + + # Check updated stats + stats = get_stats(clean_db) + assert stats["docstrings_total"] == 2 + assert stats["docstrings_matlab"] == 1 + assert stats["docstrings_python"] == 1 + + +def test_upsert_updates_existing(clean_db): + """Test that upserting the same docstring updates it.""" + from src.knowledge.db import get_connection, upsert_docstring + + with get_connection(clean_db) as conn: + # Insert first version + upsert_docstring( + conn, + repo="test/repo", + file_path="test.m", + language="matlab", + symbol_name="my_func", + symbol_type="function", + docstring="Original docstring", + line_number=1, + ) + conn.commit() + + # Update with new docstring + upsert_docstring( + conn, + repo="test/repo", + file_path="test.m", + language="matlab", + symbol_name="my_func", + symbol_type="function", + docstring="Updated docstring", + line_number=1, + ) + conn.commit() + + # Should only have one entry with updated content + results = search_docstrings("docstring", project=clean_db) + assert len(results) == 1 + assert "Updated" in results[0].snippet + + +def test_docstring_size_limit(clean_db): + """Test that docstrings are truncated if too large.""" + from src.knowledge.db import get_connection, upsert_docstring + + # Create a very large docstring + large_doc = "x" * 15000 + + with get_connection(clean_db) as conn: + upsert_docstring( + conn, + repo="test/repo", + file_path="test.m", + language="matlab", + symbol_name="big_func", + symbol_type="function", + docstring=large_doc, + line_number=1, + ) + conn.commit() + + # Check that it was truncated + row = conn.execute( + "SELECT docstring FROM docstrings WHERE symbol_name='big_func'" + ).fetchone() + assert row is not None + assert len(row["docstring"]) <= 10000 diff --git a/tests/test_knowledge/test_matlab_parser.py b/tests/test_knowledge/test_matlab_parser.py new file mode 100644 index 0000000..536184b --- /dev/null +++ b/tests/test_knowledge/test_matlab_parser.py @@ -0,0 +1,235 @@ +"""Tests for MATLAB docstring parser.""" + +from src.knowledge.matlab_parser import parse_matlab_file + + +def test_parse_function_with_docstring(): + """Test extraction of function with preceding comment block.""" + code = """% This is a test function +% It does something useful +% +% Usage: +% result = test_func(input) +% +% Input: +% input - some input data +% +% Output: +% result - the processed result + +function result = test_func(input) + result = input * 2; +end +""" + results = parse_matlab_file(code, "test_func.m") + + assert len(results) == 1 + assert results[0].symbol_name == "test_func" + assert results[0].symbol_type == "function" + assert "This is a test function" in results[0].docstring + assert "Usage:" in results[0].docstring + assert results[0].line_number < 13 # Comment starts before function + + +def test_parse_function_multiple_outputs(): + """Test function with multiple return values.""" + code = """% Calculate sum and product +% of two numbers + +function [sum_result, prod_result] = calc(a, b) + sum_result = a + b; + prod_result = a * b; +end +""" + results = parse_matlab_file(code, "calc.m") + + assert len(results) == 1 + assert results[0].symbol_name == "calc" + assert results[0].symbol_type == "function" + assert "Calculate sum and product" in results[0].docstring + + +def test_parse_function_no_outputs(): + """Test function with no return values.""" + code = """% Display a message +% to the console + +function display_message(msg) + disp(msg); +end +""" + results = parse_matlab_file(code, "display_message.m") + + assert len(results) == 1 + assert results[0].symbol_name == "display_message" + assert "Display a message" in results[0].docstring + + +def test_parse_script_with_header(): + """Test script (no function) with header comments.""" + code = """% Script to plot data +% This script loads data and creates visualizations +% +% Requirements: +% - MATLAB R2020a or later +% - Statistics Toolbox + +data = load('data.mat'); +plot(data.x, data.y); +title('My Plot'); +""" + results = parse_matlab_file(code, "plot_script.m") + + assert len(results) == 1 + assert results[0].symbol_name == "plot_script" + assert results[0].symbol_type == "script" + assert "Script to plot data" in results[0].docstring + assert "Requirements:" in results[0].docstring + + +def test_parse_function_without_docstring(): + """Test function with no preceding comments.""" + code = """ +function result = simple_func(x) + result = x + 1; +end +""" + results = parse_matlab_file(code, "simple_func.m") + + assert len(results) == 0 + + +def test_parse_multiple_functions(): + """Test file with multiple functions.""" + code = """% Main function +% Does the main work + +function output = main_func(input) + output = helper_func(input); +end + +% Helper function +% Assists the main function + +function result = helper_func(data) + result = data * 2; +end +""" + results = parse_matlab_file(code, "multi_func.m") + + # Should find both functions + assert len(results) == 2 + names = {r.symbol_name for r in results} + assert "main_func" in names + assert "helper_func" in names + + +def test_parse_comment_styles(): + """Test different MATLAB comment styles.""" + code = """%% This is a function + % It has various comment styles + % Including indented comments + % And multiple % characters + +function result = test_comments(x) + result = x; +end +""" + results = parse_matlab_file(code, "test_comments.m") + + assert len(results) == 1 + # Docstring should have comment markers stripped + assert "This is a function" in results[0].docstring + assert "%" not in results[0].docstring.split("\n")[0] + + +def test_parse_empty_file(): + """Test empty file.""" + code = "" + results = parse_matlab_file(code, "empty.m") + + assert len(results) == 0 + + +def test_parse_comments_only_no_code(): + """Test file with only comments (script with no executable code).""" + code = """% This is just documentation +% No actual code here +""" + results = parse_matlab_file(code, "comments_only.m") + + # This should be treated as a script with header comments + assert len(results) == 1 + assert results[0].symbol_type == "script" + assert "This is just documentation" in results[0].docstring + + +def test_parse_function_with_blank_lines_in_comments(): + """Test function with blank lines in comment block.""" + code = """% Function to process data +% +% This function does something useful. +% +% Args: +% input - the input data + +function output = process_data(input) + output = input; +end +""" + results = parse_matlab_file(code, "process_data.m") + + assert len(results) == 1 + assert "Function to process data" in results[0].docstring + # Blank lines should be preserved + assert "\n\n" in results[0].docstring or results[0].docstring.count("\n") > 1 + + +def test_parse_real_eeglab_style(): + """Test EEGLAB-style function documentation.""" + code = """% pop_loadset() - load an EEG dataset +% +% Usage: +% >> EEGOUT = pop_loadset; +% >> EEGOUT = pop_loadset( filename, filepath); +% +% Inputs: +% filename - [string] dataset filename +% filepath - [string] dataset filepath +% +% Outputs: +% EEGOUT - output dataset structure +% +% See also: +% pop_saveset, eeg_checkset + +function [EEG, com] = pop_loadset(filename, filepath) + % function body +end +""" + results = parse_matlab_file(code, "pop_loadset.m") + + assert len(results) == 1 + assert results[0].symbol_name == "pop_loadset" + doc = results[0].docstring + assert "pop_loadset()" in doc + assert "Usage:" in doc + assert "Inputs:" in doc + assert "Outputs:" in doc + assert "See also:" in doc + + +def test_parse_percent_sign_in_text(): + """Test handling of % within comment text.""" + code = """% Function to calculate 50% threshold +% The threshold is set at 50% of max value + +function thresh = calc_threshold(data) + thresh = max(data) * 0.5; +end +""" + results = parse_matlab_file(code, "calc_threshold.m") + + assert len(results) == 1 + # The % in "50%" should be preserved in docstring + assert "50%" in results[0].docstring diff --git a/tests/test_knowledge/test_python_parser.py b/tests/test_knowledge/test_python_parser.py new file mode 100644 index 0000000..31eaa40 --- /dev/null +++ b/tests/test_knowledge/test_python_parser.py @@ -0,0 +1,189 @@ +"""Tests for Python docstring parser.""" + +import pytest + +from src.knowledge.python_parser import parse_python_file + + +def test_parse_module_docstring(): + """Test extraction of module-level docstring.""" + code = '''"""Module docstring for test file.""" + +def foo(): + pass +''' + results = parse_python_file(code, "test.py") + + assert len(results) == 1 + assert results[0].symbol_name == "test" + assert results[0].symbol_type == "module" + assert results[0].docstring == "Module docstring for test file." + assert results[0].line_number == 1 + + +def test_parse_function_docstring(): + """Test extraction of function docstring.""" + code = ''' +def my_function(x, y): + """Add two numbers. + + Args: + x: First number + y: Second number + + Returns: + Sum of x and y + """ + return x + y +''' + results = parse_python_file(code, "test.py") + + assert len(results) == 1 + assert results[0].symbol_name == "my_function" + assert results[0].symbol_type == "function" + assert "Add two numbers" in results[0].docstring + assert results[0].line_number > 0 + + +def test_parse_class_docstring(): + """Test extraction of class docstring.""" + code = ''' +class MyClass: + """A test class. + + This class does useful things. + """ + + def __init__(self): + pass +''' + results = parse_python_file(code, "test.py") + + assert len(results) == 1 + assert results[0].symbol_name == "MyClass" + assert results[0].symbol_type == "class" + assert "A test class" in results[0].docstring + + +def test_parse_method_docstring(): + """Test extraction of method docstrings.""" + code = ''' +class Calculator: + """Calculator class.""" + + def add(self, x, y): + """Add two numbers.""" + return x + y + + def subtract(self, x, y): + """Subtract y from x.""" + return x - y +''' + results = parse_python_file(code, "test.py") + + # Should find class and 2 methods + assert len(results) == 3 + + # Check class + class_doc = [r for r in results if r.symbol_type == "class"][0] + assert class_doc.symbol_name == "Calculator" + + # Check methods + method_docs = [r for r in results if r.symbol_type == "method"] + assert len(method_docs) == 2 + method_names = {m.symbol_name for m in method_docs} + assert "add" in method_names + assert "subtract" in method_names + + +def test_parse_async_function(): + """Test extraction of async function docstring.""" + code = ''' +async def fetch_data(url): + """Fetch data asynchronously. + + Args: + url: The URL to fetch from + """ + pass +''' + results = parse_python_file(code, "test.py") + + assert len(results) == 1 + assert results[0].symbol_name == "fetch_data" + assert results[0].symbol_type == "function" + assert "Fetch data asynchronously" in results[0].docstring + + +def test_parse_no_docstrings(): + """Test file with no docstrings.""" + code = """ +def foo(): + pass + +class Bar: + pass +""" + results = parse_python_file(code, "test.py") + + assert len(results) == 0 + + +def test_parse_mixed_documented_undocumented(): + """Test file with some documented and some undocumented items.""" + code = ''' +"""Module doc.""" + +def documented(): + """This has a docstring.""" + pass + +def undocumented(): + pass + +class DocumentedClass: + """Class with docstring.""" + pass + +class UndocumentedClass: + pass +''' + results = parse_python_file(code, "test.py") + + # Should find module, 1 function, and 1 class + assert len(results) == 3 + types = {r.symbol_type for r in results} + assert types == {"module", "function", "class"} + + +def test_parse_syntax_error(): + """Test handling of syntax errors.""" + code = ''' +def invalid syntax here: + """This won't parse.""" + pass +''' + # Should raise SyntaxError (not return empty list) + with pytest.raises(SyntaxError): + parse_python_file(code, "test.py") + + +def test_parse_nested_functions(): + """Test that nested functions are extracted.""" + code = ''' +def outer(): + """Outer function.""" + + def inner(): + """Inner function.""" + pass + + return inner +''' + results = parse_python_file(code, "test.py") + + # ast.walk() will find both functions + assert len(results) == 2 + names = {r.symbol_name for r in results} + assert "outer" in names + assert "inner" in names diff --git a/workers/osa-worker/index.js b/workers/osa-worker/index.js index ff5a87b..ab51f27 100644 --- a/workers/osa-worker/index.js +++ b/workers/osa-worker/index.js @@ -198,6 +198,29 @@ async function proxyToBackend(request, env, path, body, corsHeaders, CONFIG) { signal: AbortSignal.timeout(CONFIG.REQUEST_TIMEOUT), }); + // For non-2xx responses, pass through backend error details + if (!response.ok) { + let backendError = { error: `Backend returned ${response.status}` }; + const contentType = response.headers.get('Content-Type'); + + // Try to extract backend error message + try { + if (contentType?.includes('application/json')) { + backendError = await response.json(); + } else { + const text = await response.text(); + backendError = { error: text.substring(0, 500) }; + } + } catch { + // Use default error if parsing fails + } + + return new Response(JSON.stringify(backendError), { + status: response.status, + headers: { ...corsHeaders, 'Content-Type': 'application/json' }, + }); + } + // Check if streaming response const contentType = response.headers.get('Content-Type'); if (contentType?.includes('text/event-stream')) { @@ -217,6 +240,7 @@ async function proxyToBackend(request, env, path, body, corsHeaders, CONFIG) { headers: { ...corsHeaders, 'Content-Type': 'application/json' }, }); } catch (error) { + // Only network/proxy errors reach here, not HTTP errors console.error('Backend proxy error:', { path: path, errorName: error.name,