diff --git a/docs/SLACK_SETUP.md b/docs/SLACK_SETUP.md new file mode 100644 index 0000000..8935b31 --- /dev/null +++ b/docs/SLACK_SETUP.md @@ -0,0 +1,129 @@ +# Slack API Integration Setup + +This guide explains how to set up the Slack API integration for the Google API Integration App. + +## Prerequisites + +1. A Slack workspace where you have admin permissions +2. Python environment with the required dependencies installed + +## Step 1: Create a Slack App + +1. Go to [https://api.slack.com/apps](https://api.slack.com/apps) +2. Click "Create New App" +3. Choose "From scratch" +4. Enter an app name (e.g., "Praga Core Integration") +5. Select your workspace +6. Click "Create App" + +## Step 2: Configure OAuth Scopes + +In your app's settings, go to "OAuth & Permissions" and add these Bot Token Scopes: + +### Required Scopes: +- `channels:history` - View messages in public channels +- `channels:read` - View basic information about public channels +- `groups:history` - View messages in private channels +- `groups:read` - View basic information about private channels +- `im:history` - View messages in direct messages +- `im:read` - View basic information about direct messages +- `mpim:history` - View messages in group direct messages +- `mpim:read` - View basic information about group direct messages +- `users:read` - View people in a workspace + +## Step 3: Install App to Workspace + +1. In "OAuth & Permissions", click "Install to Workspace" +2. Review the permissions and click "Allow" +3. Copy the "Bot User OAuth Token" (starts with `xoxb-`) + +## Step 4: Create Credentials File + +Create a credentials file with your app's client ID and secret: + +1. In your Slack app settings, go to "Basic Information" +2. Copy the "Client ID" and "Client Secret" +3. 
Create the file `~/.praga/secrets/slack_credentials.json`: + +```json +{ + "client_id": "your-client-id-here", + "client_secret": "your-client-secret-here" +} +``` + +**Important**: Keep this file secure and never commit it to version control! The app itself reads the client ID and secret from the `SLACK_CLIENT_ID` and `SLACK_CLIENT_SECRET` environment variables (see `env.example`), so set those as well. + +## Step 5: Test the Integration + +Run the app - it will automatically start the OAuth2 flow if no valid token exists: + +```bash +python app.py +``` + +**First Run**: The app will: +1. Display an authorization URL +2. Open your browser to authorize the app +3. Receive the redirect on a temporary local HTTPS server (`https://localhost:8787/slack/oauth/callback`); accept the browser's self-signed certificate warning to continue +4. Save the token for future use + +**Subsequent Runs**: The app will use the saved token automatically. + +Example queries: +- "Search conversations in channel general" +- "Find recent conversations from the last 3 days" +- "Search for threads containing 'meeting'" +- "Get conversation chunk C1234567890(0)" + +## API Usage Notes + +### Channel IDs +- Channel IDs start with 'C' (e.g., 'C1234567890') +- DM IDs start with 'D' (e.g., 'D1234567890') +- Group DM IDs start with 'G' (e.g., 'G1234567890') + +### Conversation Chunking +- Messages are automatically chunked by temporal proximity (1-hour windows by default) +- Each chunk has a unique ID in format: `{channel_id}({chunk_index})` +- Chunks are linked with next/previous relationships + +### Thread Handling +- Threads are identified by their `thread_ts` timestamp +- Thread data is cached after first access +- Only the parent message content is indexed for search + +### Authentication +- Uses OAuth2 flow with proper state management +- Tokens are stored securely in `~/.praga/secrets/` +- Tokens are stored without an expiry or refresh token, so no automatic refresh takes place +- No need to manually manage tokens + +### Caching +- All conversation and thread data is cached in memory +- No cache invalidation - data persists for the session +- Re-running the app will start with a fresh cache + +## Troubleshooting + +### Common Issues: + +1. **"missing_scope" error**: Make sure all required scopes are added to your app +2. 
**"channel_not_found" error**: Verify the channel ID is correct and the bot has access +3. **"not_in_channel" error**: Invite the bot to private channels before accessing them +4. **OAuth errors**: Ensure the credentials file exists and contains valid client_id/client_secret, and that the `SLACK_CLIENT_ID` and `SLACK_CLIENT_SECRET` environment variables are set +5. **"invalid_auth" error**: Delete the token files in `~/.praga/secrets/` to restart OAuth flow + +### Getting Channel IDs: +- Right-click on a channel in Slack → "Copy link" +- The ID is in the URL: `/archives/C1234567890/` +- Or use the Slack API: `conversations.list` + +## Security Notes + +- Store your credentials file securely and never commit it to version control +- OAuth2 tokens are stored in `~/.praga/secrets/` with proper permissions +- The app can only access channels you've authorized during OAuth flow +- Private channels require explicit invitation of the app +- DMs require the user to have started a conversation with the app +- Tokens are stored without an expiry and are not refreshed; delete the stored token to force re-authorization \ No newline at end of file diff --git a/env.example b/env.example index 341c8bb..e404168 100644 --- a/env.example +++ b/env.example @@ -37,14 +37,14 @@ RETRIEVER_AGENT_MODEL=gpt-4o-mini RETRIEVER_MAX_ITERATIONS=10 # ============================================================================= -# GOOGLE API CONFIGURATION +# API CONFIGURATION # ============================================================================= # Path to your Google API credentials file (default: credentials.json) GOOGLE_CREDENTIALS_FILE=credentials.json -# Path to store the Google API token (default: token.pickle) -GOOGLE_TOKEN_FILE=token.pickle +SLACK_CLIENT_ID="your_slack_app_client_id" +SLACK_CLIENT_SECRET="your_slack_client_secret" # ============================================================================= # LOGGING CONFIGURATION diff --git a/requirements-pragweb.txt b/requirements-pragweb.txt index 2077343..c0feb31 100644 --- a/requirements-pragweb.txt +++ b/requirements-pragweb.txt @@ -6,3 +6,5 @@ tqdm>=4.65.0 
python-dotenv>=1.1.0 bs4==0.0.2 chonkie>=1.0.0 +slack-sdk>=3.21.0 +uvicorn>=0.15.0 diff --git a/src/praga_core/global_context.py b/src/praga_core/global_context.py index a358f7f..0dad653 100644 --- a/src/praga_core/global_context.py +++ b/src/praga_core/global_context.py @@ -100,5 +100,5 @@ def __init__(self, api_client: Any = None, *args: Any, **kwargs: Any) -> None: @property def page_cache(self) -> PageCache: - """Access the global PageCache instance.""" + """Access the context's PageCache directly.""" return self.context.page_cache diff --git a/src/praga_core/page_cache.py b/src/praga_core/page_cache.py deleted file mode 100644 index 4c52c9d..0000000 --- a/src/praga_core/page_cache.py +++ /dev/null @@ -1,600 +0,0 @@ -"""SQL-based page cache for storing and retrieving Page instances. - -This module provides a PageCache class that automatically creates SQL tables -from Pydantic Page models and provides type-safe querying capabilities. -""" - -import logging -from datetime import datetime, timezone -from decimal import Decimal -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Type, - TypeVar, - Union, - get_args, - get_origin, -) - -from sqlalchemy import ( - JSON, - TIMESTAMP, - Boolean, - Column, - Float, - Integer, - Numeric, - String, - Text, - create_engine, -) -from sqlalchemy.engine import Engine -from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session, sessionmaker -from sqlalchemy.orm.decl_api import declarative_base - -from .types import Page, PageURI - -logger = logging.getLogger(__name__) - -# TypeVar for generic Page type support -P = TypeVar("P", bound=Page) - -# SQLAlchemy declarative base for table definitions -Base = declarative_base() - -# Global registry to reuse table classes across PageCache instances -# This prevents SQLAlchemy warnings about duplicate table definitions -_TABLE_REGISTRY: Dict[str, Any] = {} - - -def _get_base_type(field_type: Any) -> Any: - """Extract the base type from a complex 
type annotation. - - Handles Optional/Union types and container types like List, Dict. - - Args: - field_type: The type annotation to analyze - - Returns: - The base type if extractable, None for complex container types - """ - # Handle Optional/Union types - origin = get_origin(field_type) - if origin is Union: - # Get non-None type from Optional/Union - args = get_args(field_type) - non_none_types = [t for t in args if t is not type(None)] - if len(non_none_types) == 1: - return non_none_types[0] - # If multiple non-None types, treat as complex type - return None - - # Handle container types (List, Dict, etc.) - if origin is not None: - # For now, treat all container types as JSON-serializable - return None - - return field_type - - -def _get_sql_type(field_type: Any, field_info: Any) -> Any: - """Map Python/Pydantic types to appropriate SQLAlchemy column types. - - Args: - field_type: The Python/Pydantic type annotation - field_info: The Pydantic field info object - - Returns: - Appropriate SQLAlchemy column type - - Examples: - str -> String (default) or Text (if sql_type="text" in field metadata) - int -> Integer - float -> Float - bool -> Boolean - datetime -> TIMESTAMP with timezone - Optional[int] -> Integer with nullable=True - List[str] -> JSON - Dict[str, Any] -> JSON - """ - # Get the base type (handling Optional, Union, etc) - base_type = _get_base_type(field_type) - - # If no base type found (complex type), use JSON - if base_type is None: - return JSON - - # Handle string types with optional metadata for TEXT vs VARCHAR - if base_type == str: - metadata = getattr(field_info, "json_schema_extra", {}) or {} - sql_type_name = metadata.get("sql_type", "string").lower() - if sql_type_name == "text": - return Text - return String - - # Handle PageURI as string (special case) - from .types import PageURI - - if base_type == PageURI: - return String - - # Map basic Python types to SQLAlchemy types - type_mapping = { - int: Integer, - float: Float, - bool: 
Boolean, - datetime: TIMESTAMP(timezone=True), - Decimal: Numeric, - dict: JSON, - list: JSON, - } - - sql_type = type_mapping.get(base_type) - if sql_type is None: - # Fallback to JSON for unknown types - logger.debug(f"Unknown type {base_type}, using JSON column") - return JSON - - return sql_type - - -def _get_page_schema_signature(page_class: Type[P]) -> str: - """Generate a signature string for the page schema to detect changes. - - This helps detect when a Page model's schema has changed between - different runs, which could require database migrations. - - Args: - page_class: The Page class to analyze - - Returns: - A string signature representing the schema - """ - fields = [] - for field_name, field in page_class.model_fields.items(): - if field_name not in ("uri", "id"): # Skip special fields - field_type = field.annotation - sql_type = _get_sql_type(field_type, field) - is_optional = get_origin(field_type) is Union and type(None) in get_args( - field_type - ) - fields.append(f"{field_name}:{sql_type.__class__.__name__}:{is_optional}") - - return "|".join(sorted(fields)) - - -def _create_page_table(page_class: Type[P]) -> Any: - """Create or reuse a SQLAlchemy table class from a Page class. - - This function automatically generates SQLAlchemy table classes based on - Pydantic Page models. It includes automatic type mapping and reuses - existing table classes to avoid SQLAlchemy warnings. 
- - Args: - page_class: The Page class to create a table for - - Returns: - SQLAlchemy table class (declarative model) - """ - page_type_name = page_class.__name__ - table_name = f"{page_type_name.lower()}_pages" - - # Check if we already have this table class in our registry - if page_type_name in _TABLE_REGISTRY: - existing_table_class = _TABLE_REGISTRY[page_type_name] - - # Check if schema has changed since last registration - current_signature = _get_page_schema_signature(page_class) - if hasattr(existing_table_class, "_schema_signature"): - if existing_table_class._schema_signature != current_signature: - logger.warning( - f"Schema change detected for {page_type_name}. " - f"Consider dropping and recreating the database or running migrations. " - f"Reusing existing table schema for now." - ) - - logger.debug(f"Reusing existing table class for {page_type_name}") - return existing_table_class - - # Create new table class dynamically - class_name = f"{page_type_name}Table" - - # Define base table attributes - attrs = { - "__tablename__": table_name, - "uri": Column(String, primary_key=True), # URI as primary key - "created_at": Column( - TIMESTAMP(timezone=True), default=lambda: datetime.now(timezone.utc) - ), - "updated_at": Column( - TIMESTAMP(timezone=True), - default=lambda: datetime.now(timezone.utc), - onupdate=lambda: datetime.now(timezone.utc), - ), - "_schema_signature": _get_page_schema_signature(page_class), - } - - # Add page fields as columns with appropriate SQL types - for field_name, field in page_class.model_fields.items(): - if field_name not in ("uri",): # Skip uri field - handled as primary key - field_type = field.annotation - sql_type = _get_sql_type(field_type, field) - - # Make column nullable if field is Optional - is_optional = get_origin(field_type) is Union and type(None) in get_args( - field_type - ) - attrs[field_name] = Column(sql_type, nullable=is_optional) - - # Create the table class dynamically - table_class = type(class_name, 
(Base,), attrs) - - # Register in our global registry to enable reuse - _TABLE_REGISTRY[page_type_name] = table_class - - logger.debug(f"Created new table class for {page_type_name}") - return table_class - - -class PageCache: - """SQL-based cache for storing and retrieving Page instances. - - This class provides automatic SQL table generation from Pydantic Page models, - with support for storing, retrieving, and querying pages using type-safe - SQLAlchemy expressions. - - Features: - - Automatic schema synthesis from Page models - - URI-based primary keys - - Type-safe querying with SQLAlchemy expressions - - Table reuse across cache instances - - Support for complex field types via JSON columns - - Example: - cache = PageCache("sqlite:///pages.db") - - # Pages are automatically registered when first stored - user = UserPage(uri=PageURI(...), name="John", email="john@example.com") - cache.store_page(user) - - # Query with type-safe expressions - users = cache.find_pages_by_attribute( - UserPage, - lambda t: t.email.like("%@company.com") - ) - """ - - def __init__(self, url: str, drop_previous: bool = False) -> None: - """Initialize the page cache. 
- - Args: - url: Database URL (e.g., "sqlite:///cache.db", "postgresql://...") - drop_previous: Whether to drop existing tables on initialization - """ - # Configure engine based on database type - engine_args = {} - if url.startswith("postgresql"): - from sqlalchemy.pool import NullPool - - engine_args["poolclass"] = NullPool - - self._engine = create_engine(url, **engine_args) - self._session = sessionmaker(bind=self.engine) - - # Instance-specific tracking of registered page types - self._registered_types: set[str] = set() - self._table_mapping: Dict[str, str] = {} - - # Reset database if requested - if drop_previous: - self._reset() - - def _reset(self) -> None: - """Reset the database by dropping all tables and clearing registries.""" - Base.metadata.drop_all(self.engine) - - # Clear registries for clean state - _TABLE_REGISTRY.clear() - self._registered_types.clear() - self._table_mapping.clear() - - logger.debug("Reset database and cleared all registries") - - def register_page_type(self, page_type: Type[P]) -> None: - """Register a page type for caching. - - This creates the necessary SQL table structure for the page type - if it doesn't already exist. 
- - Args: - page_type: Page class to register for caching - """ - type_name = page_type.__name__ - if type_name in self._registered_types: - return # Already registered in this instance - - # Create or reuse the table class - table_class = _create_page_table(page_type) - self._table_mapping[type_name] = table_class.__tablename__ - - # Create the table in the database if it doesn't exist - table_class.__table__.create(self.engine, checkfirst=True) - - # Mark as registered in this instance - self._registered_types.add(type_name) - - logger.debug(f"Registered page type {type_name}") - - def _convert_page_uris_for_storage(self, value: Any) -> Any: - """Convert PageURI objects to strings for database storage.""" - from .types import PageURI - - if isinstance(value, PageURI): - return str(value) - elif isinstance(value, list): - return [self._convert_page_uris_for_storage(item) for item in value] - elif isinstance(value, dict): - return {k: self._convert_page_uris_for_storage(v) for k, v in value.items()} - else: - return value - - def _convert_page_uris_from_storage(self, value: Any, field_type: Any) -> Any: - """Convert strings back to PageURI objects after database retrieval.""" - from .types import PageURI - - # Get the base type, handling Optional/Union - base_type = _get_base_type(field_type) - - if base_type == PageURI and isinstance(value, str): - return PageURI.parse(value) - elif get_origin(field_type) is list: - # Handle List[PageURI] - args = get_args(field_type) - if args and args[0] == PageURI and isinstance(value, list): - return [ - PageURI.parse(item) if isinstance(item, str) else item - for item in value - ] - elif isinstance(value, dict): - # Handle nested dictionaries (though less common for PageURI) - return { - k: self._convert_page_uris_from_storage(v, field_type) - for k, v in value.items() - } - - return value - - def store_page(self, page: Page) -> bool: - """Store a page in the cache. 
- - If the page already exists (same URI), it will be updated. - Otherwise, a new record will be created. - - Args: - page: Page instance to store - - Returns: - True if page was newly created, False if updated - """ - page_type_name = page.__class__.__name__ - if page_type_name not in self._registered_types: - self.register_page_type(page.__class__) - - table_class = _TABLE_REGISTRY[page_type_name] - - with self.get_session() as session: - # Check if page already exists - existing = session.query(table_class).filter_by(uri=str(page.uri)).first() - - if existing: - # Update existing page - for field_name in page.__class__.model_fields: - if field_name not in ("uri",): - value = getattr(page, field_name) - # Convert PageURI objects to strings - converted_value = self._convert_page_uris_for_storage(value) - setattr(existing, field_name, converted_value) - existing.updated_at = datetime.now(timezone.utc) - session.commit() - return False - else: - # Create new page record - page_data = {"uri": str(page.uri)} - for field_name in page.__class__.model_fields: - if field_name not in ("uri",): - value = getattr(page, field_name) - # Convert PageURI objects to strings - page_data[field_name] = self._convert_page_uris_for_storage( - value - ) - - page_entity = table_class(**page_data) - session.add(page_entity) - try: - session.commit() - return True - except IntegrityError: - session.rollback() - return False - - def get_page(self, page_type: Type[P], uri: PageURI) -> Optional[P]: - """Retrieve a page by its type and URI. 
- - Args: - page_type: The Page class type to retrieve - uri: The PageURI to look up - - Returns: - Page instance of the requested type if found, None otherwise - """ - page_type_name = page_type.__name__ - if page_type_name not in _TABLE_REGISTRY: - return None - - table_class = _TABLE_REGISTRY[page_type_name] - - with self.get_session() as session: - entity = session.query(table_class).filter_by(uri=str(uri)).first() - - if entity: - # Convert database entity back to Page instance - page_data = {"uri": PageURI.parse(entity.uri)} - for field_name, field_info in page_type.model_fields.items(): - if field_name not in ("uri",): - value = getattr(entity, field_name) - # Convert strings back to PageURI objects - converted_value = self._convert_page_uris_from_storage( - value, field_info.annotation - ) - page_data[field_name] = converted_value - - return page_type(**page_data) - return None - - def find_pages_by_attribute( - self, - page_type: Type[P], - query_filter: Callable[[Any], bool], - ) -> List[P]: - """Find pages of a given type that match a SQLAlchemy query filter. - - This method provides type-safe querying using SQLAlchemy expressions. - You can use lambda functions for simple queries or direct SQLAlchemy - expressions for more complex cases. 
- - Args: - page_type: Page class type to query - query_filter: Either a callable that takes the table class and returns - a filter expression, or a SQLAlchemy filter expression - - Returns: - List of matching Page instances of the requested type - - Examples: - # Simple equality check - users = cache.find_pages_by_attribute( - UserPage, - lambda t: t.email == "test@example.com" - ) - - # Pattern matching - users = cache.find_pages_by_attribute( - UserPage, - lambda t: t.name.ilike("%john%") - ) - - # Complex conditions - users = cache.find_pages_by_attribute( - UserPage, - lambda t: (t.age > 18) & (t.email.like("%@company.com")) - ) - - # Direct table reference (advanced) - table = cache._get_table_class(UserPage) - users = cache.find_pages_by_attribute( - UserPage, - table.email == "test@example.com" - ) - """ - page_type_name = page_type.__name__ - if page_type_name not in _TABLE_REGISTRY: - return [] - - table_class = _TABLE_REGISTRY[page_type_name] - - with self.get_session() as session: - query = session.query(table_class) - - # Apply the filter based on its type - if callable(query_filter): - # Lambda function that takes table class and returns filter - filter_expr = query_filter(table_class) - query = query.filter(filter_expr) - else: - # Direct SQLAlchemy filter expression - query = query.filter(query_filter) - - entities = query.all() - results = [] - - # Convert database entities back to Page instances - for entity in entities: - page_data = {"uri": PageURI.parse(entity.uri)} - for field_name, field_info in page_type.model_fields.items(): - if field_name not in ("uri",): - value = getattr(entity, field_name) - # Convert strings back to PageURI objects - converted_value = self._convert_page_uris_from_storage( - value, field_info.annotation - ) - page_data[field_name] = converted_value - - results.append(page_type(**page_data)) - - return results - - def _get_table_class(self, page_type: Type[P]) -> Any: - """Get the SQLAlchemy table class for a page type. 
- - This is useful for creating direct filter expressions when you need - more control over the query construction. - - Args: - page_type: Page class type - - Returns: - SQLAlchemy table class (declarative model) - - Raises: - ValueError: If page type is not registered - - Example: - table = cache._get_table_class(UserPage) - users = cache.find_pages_by_attribute( - UserPage, - table.email.in_(["user1@example.com", "user2@example.com"]) - ) - """ - page_type_name = page_type.__name__ - if page_type_name not in _TABLE_REGISTRY: - raise ValueError(f"Page type {page_type_name} not registered") - return _TABLE_REGISTRY[page_type_name] - - @property - def engine(self) -> Engine: - """Get the SQLAlchemy engine instance.""" - return self._engine - - def get_session(self) -> Session: - """Get a new database session. - - Returns: - SQLAlchemy session instance - - Note: - Remember to close the session when done, or use it in a context manager. - """ - return self._session() - - @property - def registered_page_types(self) -> List[str]: - """Get list of page type names registered in this cache instance. - - Returns: - List of registered page type names - """ - return list(self._registered_types) - - @property - def table_mapping(self) -> Dict[str, str]: - """Get mapping from page type names to database table names. 
- - Returns: - Dictionary mapping page type names to table names - """ - return self._table_mapping.copy() diff --git a/src/pragweb/app.py b/src/pragweb/app.py index 5fe3531..4b62ffc 100644 --- a/src/pragweb/app.py +++ b/src/pragweb/app.py @@ -11,6 +11,7 @@ from pragweb.google_api.docs import GoogleDocsService from pragweb.google_api.gmail import GmailService from pragweb.google_api.people import PeopleService +from pragweb.slack import SlackService logging.basicConfig(level=getattr(logging, get_current_config().log_level)) @@ -37,6 +38,7 @@ def setup_global_context() -> None: calendar_service = CalendarService(google_client) people_service = PeopleService(google_client) google_docs_service = GoogleDocsService(google_client) + slack_service = SlackService() # Collect all toolkits from registered services logger.info("Collecting toolkits...") @@ -45,6 +47,7 @@ def setup_global_context() -> None: calendar_service.toolkit, people_service.toolkit, google_docs_service.toolkit, + slack_service.toolkit, ] # Set up agent with collected toolkits diff --git a/src/pragweb/config.py b/src/pragweb/config.py index ba870c8..36e7895 100644 --- a/src/pragweb/config.py +++ b/src/pragweb/config.py @@ -39,6 +39,14 @@ class AppConfig(BaseModel): description="Path to Google API credentials file" ) + # Slack API Configuration + slack_client_id: Optional[str] = Field( + default=None, description="Slack app client ID" + ) + slack_client_secret: Optional[str] = Field( + default=None, description="Slack app client secret" + ) + # Logging Configuration log_level: str = Field(description="Logging level") @@ -89,6 +97,8 @@ def load_default_config() -> AppConfig: google_credentials_file=os.getenv( "GOOGLE_CREDENTIALS_FILE", "credentials.json" ), + slack_client_id=os.getenv("SLACK_CLIENT_ID"), + slack_client_secret=os.getenv("SLACK_CLIENT_SECRET"), log_level=os.getenv("LOG_LEVEL", "INFO"), ) diff --git a/src/pragweb/slack/__init__.py b/src/pragweb/slack/__init__.py new file mode 100644 index 
class SlackOAuthServer:
    """Temporary local HTTPS server that receives Slack's OAuth redirect.

    The server listens on ``https://localhost:<port>/slack/oauth/callback``
    with a throw-away self-signed certificate, validates the ``state``
    parameter of the redirect, and hands the authorization ``code`` back to
    the thread that started the flow. The server is asked to shut down once
    the flow finishes (successfully or not) so the port is released and the
    temporary key/certificate directory is removed.
    """

    def __init__(self, port: int = 8787):
        """Configure the callback endpoint.

        Args:
            port: Local TCP port the callback server listens on.
        """
        self.port = port
        self.redirect_uri = f"https://localhost:{port}/slack/oauth/callback"
        # Handle to the running uvicorn server; set by the server thread so
        # the main thread can observe start-up and request shutdown.
        self._server: Optional[Any] = None

    def run_oauth_flow(self, auth_url: str, expected_state: str) -> Dict[str, Any]:
        """Run the interactive OAuth flow and return the callback parameters.

        Args:
            auth_url: Slack authorization URL to open in the browser.
            expected_state: The ``state`` value embedded in ``auth_url``; the
                callback is rejected unless it echoes this value.

        Returns:
            Dict with the ``code`` and ``state`` received on the callback.

        Raises:
            ValueError: If the callback server fails, Slack reports an error,
                the state does not match, or no callback arrives in time.
        """
        print("\nSlack OAuth Required:")
        print(f"Starting HTTPS server on localhost:{self.port}")
        print(
            "Your browser will show a security warning for the self-signed certificate"
        )
        print("Click 'Advanced' -> 'Proceed to localhost' to continue")
        print("Opening browser for authorization...")

        # Shared mutable container: written by the server thread, polled here.
        oauth_result: Dict[str, Any] = {"code": None, "state": None, "error": None}
        asgi_app = self._create_asgi_app(oauth_result, expected_state)

        def serve() -> None:
            # Surface server failures (openssl missing, port in use, ...) to
            # the waiting thread instead of losing them inside the thread.
            try:
                self._run_https_server(asgi_app)
            except (Exception, SystemExit) as exc:
                if not oauth_result["code"]:
                    oauth_result["error"] = f"Callback server failed: {exc!r}"
            else:
                if not oauth_result["code"] and not oauth_result["error"]:
                    oauth_result["error"] = (
                        "Callback server stopped before authorization completed"
                    )

        server_thread = threading.Thread(target=serve, daemon=True)
        server_thread.start()

        try:
            # Wait for the server to report readiness rather than sleeping a
            # fixed time; certificate generation speed varies by machine.
            self._wait_for_server_start(oauth_result)
            if oauth_result["error"]:
                raise ValueError(f"OAuth failed: {oauth_result['error']}")

            # webbrowser.open() signals most failures by returning False.
            try:
                opened = webbrowser.open(auth_url)
            except Exception as e:
                print(f"Could not open browser automatically: {e}")
                opened = None
            if opened:
                print(f"Browser opened to: {auth_url}")
            else:
                if opened is False:
                    print("Could not open browser automatically")
                print(f"Please visit: {auth_url}")

            return self._wait_for_callback(oauth_result)
        finally:
            self._stop_server()

    def _wait_for_server_start(
        self, oauth_result: Dict[str, Any], timeout: float = 15.0
    ) -> None:
        """Block until the server is accepting connections, fails, or times out."""
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if oauth_result["error"]:
                return
            server = self._server
            if server is not None and getattr(server, "started", False):
                return
            time.sleep(0.1)

    def _stop_server(self) -> None:
        """Ask the running uvicorn server (if any) to shut down gracefully."""
        server = self._server
        if server is not None:
            server.should_exit = True

    def _wait_for_callback(
        self,
        oauth_result: Dict[str, Any],
        timeout: float = 180.0,
        poll_interval: float = 0.5,
    ) -> Dict[str, Any]:
        """Poll the shared result until the callback arrives.

        Args:
            oauth_result: Container filled in by the callback handler.
            timeout: Maximum number of seconds to wait (default 3 minutes).
            poll_interval: Seconds to sleep between polls.

        Returns:
            Dict with the received ``code`` and ``state``.

        Raises:
            ValueError: If the callback reported an error or none arrived
                before ``timeout`` elapsed.
        """
        print("Waiting for authorization...")
        # Monotonic clock: immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            if oauth_result["error"]:
                raise ValueError(f"OAuth failed: {oauth_result['error']}")
            if oauth_result["code"]:
                print("Authorization successful!")
                return {"code": oauth_result["code"], "state": oauth_result["state"]}
            time.sleep(poll_interval)

        raise ValueError(
            "OAuth timeout - authorization not completed within "
            f"{timeout / 60:g} minutes"
        )

    def _run_https_server(self, app: Callable[..., Any]) -> None:
        """Serve ``app`` over HTTPS with a throw-away self-signed certificate.

        Blocks until the server exits. The key and certificate live in a
        temporary directory that is removed when the server stops.

        Raises:
            subprocess.CalledProcessError: If ``openssl`` fails.
            FileNotFoundError: If ``openssl`` is not installed.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            cert_file = os.path.join(temp_dir, "cert.pem")
            key_file = os.path.join(temp_dir, "key.pem")

            # Generate self-signed certificate valid for one day.
            subprocess.run(
                [
                    "openssl",
                    "req",
                    "-x509",
                    "-newkey",
                    "rsa:4096",
                    "-keyout",
                    key_file,
                    "-out",
                    cert_file,
                    "-days",
                    "1",
                    "-nodes",
                    "-subj",
                    "/CN=localhost",
                ],
                check=True,
                capture_output=True,
            )

            # Use Server/Config instead of uvicorn.run() so the instance can
            # be told to exit once the OAuth flow is over.
            config = uvicorn.Config(
                app,
                host="localhost",
                port=self.port,
                ssl_keyfile=key_file,
                ssl_certfile=cert_file,
                log_level="error",
            )
            server = uvicorn.Server(config)
            self._server = server
            try:
                server.run()
            finally:
                self._server = None

    def _create_asgi_app(
        self, oauth_result: Dict[str, Any], expected_state: str
    ) -> Callable[..., Awaitable[None]]:
        """Build the ASGI callable that routes callback, health and 404 requests.

        Args:
            oauth_result: Container the callback handler writes its outcome to.
            expected_state: State value the callback must echo.

        Returns:
            An ASGI application; non-HTTP scopes are ignored.
        """

        async def app(
            scope: Dict[str, Any],
            receive: Callable[..., Any],
            send: Callable[..., Awaitable[None]],
        ) -> None:
            if scope["type"] != "http":
                return

            path = scope["path"]
            # Tolerate undecodable bytes instead of failing the request.
            query_string = scope.get("query_string", b"").decode("utf-8", "replace")
            query_params = parse_qs(query_string)

            if path == "/slack/oauth/callback":
                await self._handle_oauth_callback(
                    oauth_result, expected_state, query_params, send
                )
            elif path == "/health":
                await self._handle_health(send)
            else:
                await self._handle_404(send)

        return app

    async def _handle_oauth_callback(
        self,
        oauth_result: Dict[str, Any],
        expected_state: str,
        query_params: Dict[str, List[str]],
        send: Callable[..., Awaitable[None]],
    ) -> None:
        """Validate the redirect parameters and record the outcome.

        On success ``oauth_result["code"]`` and ``["state"]`` are set; on any
        failure ``oauth_result["error"]`` is set and a 400 response is sent.
        """
        # Slack reports a denied/failed authorization via ?error=...
        error = query_params.get("error", [None])[0]
        if error:
            error_desc = query_params.get("error_description", ["Unknown error"])[0]
            oauth_result["error"] = f"{error}: {error_desc}"
            await self._send_response(
                send,
                400,
                b"application/json",
                b'{"status": "error", "message": "OAuth failed"}',
            )
            return

        code = query_params.get("code", [None])[0]
        state = query_params.get("state", [None])[0]

        if not code:
            oauth_result["error"] = "No authorization code received"
            await self._send_error_response(send, "No authorization code received")
            return

        # Constant-time comparison of the anti-CSRF state; compare bytes so
        # non-ASCII input cannot raise.
        state_matches = state is not None and secrets.compare_digest(
            state.encode(), expected_state.encode()
        )
        if not state_matches:
            oauth_result["error"] = "Invalid state parameter"
            await self._send_error_response(send, "Invalid state parameter")
            return

        oauth_result["code"] = code
        oauth_result["state"] = state

        await self._send_success_response(send)

    async def _send_response(
        self,
        send: Callable[..., Awaitable[None]],
        status: int,
        content_type: bytes,
        body: bytes,
    ) -> None:
        """Send a complete HTTP response (start + body) over the ASGI channel."""
        await send(
            {
                "type": "http.response.start",
                "status": status,
                "headers": [(b"content-type", content_type)],
            }
        )
        await send({"type": "http.response.body", "body": body})

    async def _send_error_response(
        self, send: Callable[..., Awaitable[None]], message: str
    ) -> None:
        """Send a 400 JSON error response carrying ``message``."""
        import json

        # json.dumps escapes quotes/backslashes so the body is always valid JSON.
        body = json.dumps({"status": "error", "message": message}).encode()
        await self._send_response(send, 400, b"application/json", body)

    async def _send_success_response(
        self, send: Callable[..., Awaitable[None]]
    ) -> None:
        """Send the HTML page shown in the browser after a successful callback."""
        # Minimal self-contained page around the user-facing sentence.
        success_html = (
            "<!DOCTYPE html>\n"
            "<html>\n"
            '<head><meta charset="utf-8"><title>Slack Authorization</title></head>\n'
            "<body>\n"
            "<p>You can now close this window and return to your application.</p>\n"
            "</body>\n"
            "</html>\n"
        ).encode()

        await self._send_response(send, 200, b"text/html", success_html)

    async def _handle_health(self, send: Callable[..., Awaitable[None]]) -> None:
        """Respond to ``/health`` with a 200 JSON status."""
        await self._send_response(
            send, 200, b"application/json", b'{"status": "ok"}'
        )

    async def _handle_404(self, send: Callable[..., Awaitable[None]]) -> None:
        """Respond to unknown paths with a 404 JSON status."""
        await self._send_response(
            send, 404, b"application/json", b'{"status": "not_found"}'
        )
+ ) + + return config.slack_client_id, config.slack_client_secret + + def _load_token(self, secrets_manager: SecretsManager) -> Optional[Dict[str, Any]]: + """Load existing Slack token from secrets manager.""" + token_data = secrets_manager.get_oauth_token("slack") + if not token_data: + logger.debug("No existing Slack token found") + return None + + return { + "access_token": token_data["access_token"], + "token_type": token_data.get("token_type", "Bearer"), + "scope": token_data.get("scopes", _DEFAULT_SCOPES), + "user_id": token_data.get("extra_data", {}).get("user_id"), + "team_id": token_data.get("extra_data", {}).get("team_id"), + "team_name": token_data.get("extra_data", {}).get("team_name"), + } + + def _store_token( + self, token_data: Dict[str, Any], secrets_manager: SecretsManager + ) -> None: + """Store Slack token in secrets manager.""" + # Prepare extra data + extra_data = {} + if token_data.get("user_id"): + extra_data["user_id"] = token_data["user_id"] + if token_data.get("team_id"): + extra_data["team_id"] = token_data["team_id"] + if token_data.get("team_name"): + extra_data["team_name"] = token_data["team_name"] + + # Store in secrets manager + secrets_manager.store_oauth_token( + service_name="slack", + access_token=token_data["access_token"], + refresh_token=None, # Slack doesn't use refresh tokens + token_type=token_data.get("token_type", "Bearer"), + expires_at=None, # Slack tokens don't expire + scopes=token_data.get("scope", _DEFAULT_SCOPES), + extra_data=extra_data if extra_data else None, + ) + + def _generate_auth_url( + self, client_id: str, user_scopes: List[str], redirect_uri: str + ) -> Tuple[str, str]: + """Generate Slack OAuth authorization URL.""" + # Generate a random state for security + state = secrets.token_urlsafe(32) + + # Create authorization URL + auth_url_generator = AuthorizeUrlGenerator( + client_id=client_id, + scopes=[], + user_scopes=user_scopes, + redirect_uri=redirect_uri, + ) + + auth_url = 
def _exchange_code_for_token(
    self,
    code: str,
    client_id: str,
    client_secret: str,
    redirect_uri: Optional[str] = None,
) -> Dict[str, Any]:
    """Exchange an OAuth authorization code for an access token.

    Args:
        code: Authorization code received on the OAuth callback.
        client_id: Slack app client ID.
        client_secret: Slack app client secret.
        redirect_uri: Redirect URI used when the authorization URL was
            generated. Optional and omitted from the request when ``None``.
            NOTE(review): Slack documents that this must match the URI sent
            in the authorize step; callers that sent one should pass it here.

    Returns:
        A dict with ``access_token``, ``token_type``, ``scope`` (list of
        scope names), ``user_id``, ``team_id`` and ``team_name``.

    Raises:
        ValueError: If the exchange fails for any reason; the original
            exception is attached as ``__cause__``.
    """
    # An unauthenticated client is enough: oauth.v2.access authenticates with
    # the client ID and secret.
    oauth_client = WebClient()

    try:
        slack_response = oauth_client.oauth_v2_access(
            client_id=client_id,
            client_secret=client_secret,
            code=code,
            redirect_uri=redirect_uri,
        )

        response_data = slack_response.data
        if isinstance(response_data, bytes):
            raise ValueError(
                "Unexpected response format: received bytes instead of dict"
            )
        response: Dict[str, Any] = response_data

        if not response.get("ok"):
            raise ValueError(f"OAuth exchange failed: {response.get('error')}")

        # Normalize the optional nested sections once so later lookups need
        # no repeated type checks.
        authed_user = response.get("authed_user")
        if not isinstance(authed_user, dict):
            authed_user = {}
        team = response.get("team")
        if not isinstance(team, dict):
            team = {}

        # With user scopes the token lives under ``authed_user`` rather than
        # at the top level.
        access_token = response.get("access_token") or authed_user.get(
            "access_token"
        )
        if not access_token:
            raise ValueError(
                f"No access token found in response. Response keys: {list(response.keys())}"
            )

        # Prefer the user-level scope, fall back to the top-level one, and
        # only then to the defaults.
        granted = authed_user.get("scope") or response.get("scope")
        scopes: List[str]
        if not granted:
            # Copy so callers can never mutate the module-level default list.
            scopes = list(_DEFAULT_SCOPES)
        elif isinstance(granted, str):
            scopes = granted.split(",")
        else:
            scopes = list(granted)

        return {
            "access_token": access_token,
            "token_type": response.get("token_type", "Bearer"),
            "scope": scopes,
            "user_id": authed_user.get("id"),
            "team_id": team.get("id"),
            "team_name": team.get("name"),
        }
    except Exception as e:
        # Chain explicitly so the underlying Slack/network error stays visible.
        raise ValueError(f"Failed to exchange code for token: {e}") from e
def _authenticate(self) -> None:
    """Authenticate with Slack, reusing a stored token when it still works.

    A stored token is loaded and validated with ``_test_token``. If it is
    missing or rejected, the interactive OAuth flow is run and the new token
    is persisted.

    Raises:
        RuntimeError: If the OAuth flow, token storage or client creation
            fails; the original exception is attached as ``__cause__``.
    """
    config = get_current_config()
    secrets_manager = get_secrets_manager(config.secrets_database_url)

    stored = self._load_token(secrets_manager)
    if stored and stored.get("access_token"):
        if self._test_token(stored["access_token"]):
            self._token_data = stored
            self._client = WebClient(token=stored["access_token"])
            logger.info("Successfully authenticated with existing Slack token")
            return
        logger.warning("Existing Slack token is invalid")

    # No usable stored token: make sure no stale data stays on the instance.
    self._token_data = None

    try:
        token_data = self._run_oauth_flow()
        self._store_token(token_data, secrets_manager)
        client = WebClient(token=token_data["access_token"])
    except Exception as e:
        raise RuntimeError(f"Failed to authenticate with Slack: {e}") from e

    # Publish state only after every step succeeded, so a failure never
    # leaves token data without a matching client.
    self._token_data = token_data
    self._client = client
    logger.info("Successfully authenticated with new Slack token")
def refresh_token(self) -> None:
    """Obtain a new token by re-running the OAuth flow.

    The current token data and client are replaced only after the new token
    has been obtained and stored, so a failed refresh leaves the existing
    session untouched.

    Raises:
        RuntimeError: If the OAuth flow, token storage or client creation
            fails; the original exception is attached as ``__cause__``.
    """
    logger.info("Refreshing Slack token by re-running OAuth flow")
    config = get_current_config()
    secrets_manager = get_secrets_manager(config.secrets_database_url)

    try:
        token_data = self._run_oauth_flow()
        self._store_token(token_data, secrets_manager)
        client = WebClient(token=token_data["access_token"])
    except Exception as e:
        raise RuntimeError(f"Failed to refresh Slack token: {e}") from e

    self._token_data = token_data
    self._client = client
    logger.info("Successfully refreshed Slack token")
def get_channel_info(self, channel_id: str) -> Dict[str, Any]:
    """Fetch the metadata dict for one channel.

    Args:
        channel_id: Slack channel ID to look up.

    Returns:
        The ``channel`` object from the ``conversations_info`` response.

    Raises:
        SlackApiError: If the response is not marked ``ok``.
        ValueError: If the response is ``ok`` but ``channel`` is not a dict.
    """
    reply: SlackResponse = self.client.conversations_info(channel=channel_id)
    payload = cast(Dict[str, Any], reply.data)

    # Failure path first, so the happy path reads straight down.
    if not payload.get("ok"):
        error_msg = f"Failed to get channel info: {payload.get('error')}"
        raise SlackApiError(error_msg, reply)  # type: ignore[no-untyped-call]

    channel = payload.get("channel")
    if isinstance(channel, dict):
        return cast(Dict[str, Any], channel)
    raise ValueError("Invalid channel data received")
def get_channel_members(self, channel_id: str) -> List[str]:
    """Collect the member IDs of a channel, following cursor pagination.

    Args:
        channel_id: Slack channel ID whose members are listed.

    Returns:
        Member IDs from all pages, in the order the API returned them.

    Raises:
        SlackApiError: If any page of the response is not marked ``ok``.
    """
    collected: List[str] = []
    next_page: Optional[str] = None
    more_pages = True

    while more_pages:
        reply: SlackResponse = self.client.conversations_members(
            channel=channel_id, cursor=next_page
        )
        payload = cast(Dict[str, Any], reply.data)

        if not payload.get("ok"):
            error_msg = (
                f"Failed to get channel members: {payload.get('error')}"
            )
            raise SlackApiError(error_msg, reply)  # type: ignore[no-untyped-call]

        batch = payload.get("members", [])
        if isinstance(batch, list):
            collected += cast(List[str], batch)

        # An absent or empty cursor marks the last page.
        meta = payload.get("response_metadata", {})
        next_page = meta.get("next_cursor") if isinstance(meta, dict) else None
        more_pages = bool(next_page)

    return collected
def get_thread_replies(
    self, channel_id: str, thread_ts: str
) -> List[Dict[str, Any]]:
    """Get all messages in a thread, following cursor pagination.

    Args:
        channel_id: Channel that contains the thread.
        thread_ts: Timestamp identifying the thread.

    Returns:
        The thread's messages in the order the API returned them. A message
        whose ``ts`` was already collected from an earlier page is skipped.

    Raises:
        SlackApiError: If any page of the response is not marked ``ok``.
    """
    replies: List[Dict[str, Any]] = []
    seen_ts = set()
    cursor: Optional[str] = None

    while True:
        response: SlackResponse = self.client.conversations_replies(
            channel=channel_id, ts=thread_ts, cursor=cursor
        )
        response_data = cast(Dict[str, Any], response.data)

        if not response_data.get("ok"):
            error_msg = f"Failed to get thread replies: {response_data.get('error')}"
            raise SlackApiError(error_msg, response)  # type: ignore[no-untyped-call]

        messages_data = response_data.get("messages", [])
        if isinstance(messages_data, list):
            for message in cast(List[Dict[str, Any]], messages_data):
                ts = message.get("ts")
                # NOTE(review): later pages appear to repeat the parent
                # message; de-duplicating by ts keeps one copy either way.
                if ts is not None and ts in seen_ts:
                    continue
                seen_ts.add(ts)
                replies.append(message)

        # An absent or empty cursor marks the last page.
        metadata = response_data.get("response_metadata", {})
        cursor = metadata.get("next_cursor") if isinstance(metadata, dict) else None
        if not cursor:
            break

    return replies
def search_messages_in_channel(
    self,
    channel_id: str,
    query: str = "",
    oldest: Optional[str] = None,
    latest: Optional[str] = None,
    limit: int = 100,
) -> List[Dict[str, Any]]:
    """Search messages in one channel by composing a Slack search query.

    Args:
        channel_id: Channel to restrict the search to (``in:<#ID>``).
        query: Optional free-text terms placed at the front of the query.
        oldest: Optional epoch timestamp string, turned into ``after:DATE``.
        latest: Optional epoch timestamp string, turned into ``before:DATE``.
        limit: Passed to ``search_messages`` as ``count``.

    Returns:
        The matches returned by ``search_messages``; pagination is dropped.
    """

    def as_day(epoch: str) -> str:
        # Search date modifiers take calendar days, not timestamps.
        return datetime.fromtimestamp(float(epoch)).strftime("%Y-%m-%d")

    terms = [query] if query else []
    terms.append(f"in:<#{channel_id}>")
    terms.extend(
        f"{modifier}:{as_day(bound)}"
        for modifier, bound in (("after", oldest), ("before", latest))
        if bound
    )

    matches, _pagination = self.search_messages(" ".join(terms), count=limit)
    return matches
def lookup_user_by_email(self, email: str) -> Optional[Dict[str, Any]]:
    """Look up a user by email address.

    Args:
        email: Email address to search for.

    Returns:
        The ``user`` object from the response, or ``None`` when Slack reports
        ``users_not_found`` or the response carries no user dict.

    Raises:
        SlackApiError: For any API failure other than ``users_not_found``.
    """
    try:
        response: SlackResponse = self.client.users_lookupByEmail(email=email)
    except SlackApiError as exc:
        # The SDK raises for ``ok: false`` responses, so "not found" arrives
        # here as an exception rather than as a response to inspect.
        failed = getattr(exc, "response", None)
        if failed is not None and failed.get("error") == "users_not_found":
            return None
        raise

    response_data = cast(Dict[str, Any], response.data)

    if not response_data.get("ok"):
        # Kept for clients that return failed responses instead of raising.
        if response_data.get("error") == "users_not_found":
            return None
        error_msg = f"Failed to lookup user by email: {response_data.get('error')}"
        raise SlackApiError(error_msg, response)  # type: ignore[no-untyped-call]

    user_data = response_data.get("user")
    if isinstance(user_data, dict):
        return cast(Dict[str, Any], user_data)
    return None
not isinstance(team_data, dict): + raise ValueError("Invalid team data received") + return cast(Dict[str, Any], team_data) + else: + error_msg = f"Failed to get team info: {response_data.get('error')}" + raise SlackApiError(error_msg, response) # type: ignore[no-untyped-call] diff --git a/src/pragweb/slack/ingestion.py b/src/pragweb/slack/ingestion.py new file mode 100644 index 0000000..8563f18 --- /dev/null +++ b/src/pragweb/slack/ingestion.py @@ -0,0 +1,354 @@ +"""Slack ingestion service for bulk data operations and channel initialization.""" + +import logging +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, Dict, List + +from praga_core.types import PageURI + +from .page import SlackChannelListPage, SlackConversationPage +from .utils import SlackParser + +if TYPE_CHECKING: + from .service import SlackService + +logger = logging.getLogger(__name__) + + +class SlackIngestionService: + """Sidecar service for Slack bulk data ingestion and channel operations.""" + + def __init__(self, slack_service: "SlackService"): + """Initialize with reference to main slack service.""" + self.slack_service = slack_service + self.api_client = slack_service.api_client + self.context = slack_service.context + self.page_cache = slack_service.page_cache + self.parser = SlackParser() + + def initialize_channel_data(self) -> None: + """Initialize channel data by ingesting all channels.""" + try: + logger.info("Initializing Slack channel data...") + + # Get workspace ID from auth test + auth_info = self.api_client.test_auth() + workspace_id = auth_info.get("team_id", "unknown") + + # Check if channel list page already exists + uri = PageURI( + root=self.context.root, + type="slack_channel_list", + id=workspace_id, + version=1, + ) + existing_page = self.page_cache.get_page(SlackChannelListPage, uri) + + if existing_page: + logger.info("Channel data already exists in cache") + return + + # Create new channel list page + self.ingest_all_channels() + 
def ingest_all_channels(self) -> int:
    """Make sure the workspace channel list page exists and report its size.

    The workspace ID comes from the auth test. If a channel list page for
    that workspace is already cached it is reused; otherwise a new one is
    built with ``create_channel_list_page``.

    Returns:
        ``total_channels`` of the cached or newly created page.
    """
    team_id = self.api_client.test_auth().get("team_id", "unknown")

    cached = self.page_cache.get_page(
        SlackChannelListPage,
        PageURI(
            root=self.context.root,
            type="slack_channel_list",
            id=team_id,
            version=1,
        ),
    )

    if not cached:
        # Nothing cached yet: build the page from the API.
        fresh = self.create_channel_list_page(team_id)
        logger.info(f"Ingested {fresh.total_channels} channels")
        return fresh.total_channels

    logger.info(
        f"Channel list already cached with {cached.total_channels} channels"
    )
    return cached.total_channels
+ """ + # Check for existing page unless refresh is requested + if not refresh: + uri = PageURI( + root=self.context.root, + type="slack_channel_list", + id=workspace_id, + version=1, + ) + existing_page = self.page_cache.get_page(SlackChannelListPage, uri) + if existing_page: + # Check if data is stale (older than 1 hour) + if ( + datetime.now(timezone.utc) - existing_page.last_updated + ).total_seconds() < 3600: + logger.info("Using cached channel list page (fresh)") + return existing_page + else: + logger.info("Cached channel list page is stale, refreshing...") + + # Get workspace info from auth test (no team scope required) + auth_info = self.api_client.test_auth() + workspace_name = auth_info.get("team", "Unknown Workspace") + workspace_id = auth_info.get("team_id", workspace_id) + + # Get all channels with automatic pagination handling + logger.info("Fetching all workspace channels...") + all_channels = self.api_client.list_channels() + + # Count channel types and extract data + public_count = 0 + private_count = 0 + channel_data = [] + + for channel in all_channels: + channel_type = self.parser.determine_channel_type(channel) + if channel_type == "public_channel": + public_count += 1 + elif channel_type == "private_channel": + private_count += 1 + + # Extract relevant channel info + channel_info = { + "id": channel.get("id"), + "name": channel.get("name"), + "type": channel_type, + "topic": channel.get("topic", {}).get("value"), + "purpose": channel.get("purpose", {}).get("value"), + "member_count": channel.get("num_members", 0), + "is_archived": channel.get("is_archived", False), + } + channel_data.append(channel_info) + + # Create URI + uri = PageURI( + root=self.context.root, + type="slack_channel_list", + id=workspace_id, + version=1, + ) + + # Create page + channel_list_page = SlackChannelListPage( + uri=uri, + workspace_id=workspace_id, + workspace_name=workspace_name, + total_channels=len(all_channels), + public_channels=public_count, + 
def ingest_channel(self, channel_id: str) -> int:
    """Fetch a channel's full history, chunk it, and cache the chunks.

    Args:
        channel_id: Slack channel ID to ingest.

    Returns:
        Number of conversation pages for which ``store_page`` returned a
        truthy value.
    """
    logger.info(f"Starting ingestion of channel {channel_id}")

    # Channel name and type come from the main service's channel page.
    channel = self.slack_service.get_channel_page(channel_id)
    name, kind = channel.name, channel.channel_type

    # Walk the history page by page until no further cursor is returned.
    history = []
    page_cursor = None
    while True:
        batch, page_cursor = self.api_client.get_conversation_history(
            channel_id=channel_id, limit=1000, cursor=page_cursor
        )
        history.extend(batch)
        if not page_cursor:
            break

    logger.info(f"Fetched {len(history)} messages from channel {channel_id}")

    pages = self._chunk_messages_by_time(history, channel_id, name, kind)

    stored_count = sum(1 for page in pages if self.page_cache.store_page(page))

    logger.info(
        f"Created and stored {stored_count} conversation pages for channel {channel_id}"
    )
    return stored_count
def _create_conversation_page(
    self,
    messages: List[Dict[str, Any]],
    channel_id: str,
    channel_name: str,
    channel_type: str,
    chunk_index: int,
) -> SlackConversationPage:
    """Create a SlackConversationPage from a chunk of messages.

    Args:
        messages: Non-empty list of raw Slack message dicts for this chunk.
        channel_id: Channel the messages belong to.
        channel_name: Channel name stored on the page.
        channel_type: Channel type stored on the page.
        chunk_index: Position of this chunk within the channel; part of the
            conversation ID.

    Returns:
        The page describing this chunk. It is not stored in the cache here.

    Raises:
        ValueError: If ``messages`` is empty.
    """
    if not messages:
        raise ValueError("Cannot create conversation page from empty messages")

    # Time range covered by the chunk (message ``ts`` values are epoch seconds).
    timestamps = [float(msg.get("ts", "0")) for msg in messages]
    earliest = min(timestamps)
    start_time = datetime.fromtimestamp(earliest, tz=timezone.utc)
    end_time = datetime.fromtimestamp(max(timestamps), tz=timezone.utc)

    # Resolve each user ID at most once per page; the previous resolver could
    # call ``get_user_page`` up to three times for a single lookup.
    name_cache: Dict[str, Any] = {}

    def resolve_display_name(user_id: str) -> Any:
        if user_id not in name_cache:
            user_page = self.slack_service.get_user_page(user_id)
            name_cache[user_id] = (
                user_page.display_name or user_page.real_name or user_page.name
            )
        return name_cache[user_id]

    # dict.fromkeys de-duplicates while keeping first-appearance order, so the
    # participants list is deterministic (a set gave arbitrary ordering).
    user_ids = dict.fromkeys(msg.get("user", "") for msg in messages)
    participants = [resolve_display_name(uid) for uid in user_ids if uid]

    # Format message content using the parser and the cached resolver.
    messages_content = self.parser.format_messages_for_content(
        messages, resolve_display_name
    )

    conversation_id = f"{channel_id}_{int(earliest)}_{chunk_index}"

    # Permalink points at the first message of the chunk as given.
    first_ts = messages[0].get("ts", "")
    permalink = (
        f"https://slack.com/app_redirect?channel={channel_id}&message_ts={first_ts}"
    )

    uri = PageURI(
        root=self.context.root,
        type="slack_conversation",
        id=conversation_id,
        version=1,
    )

    return SlackConversationPage(
        uri=uri,
        conversation_id=conversation_id,
        channel_id=channel_id,
        channel_name=channel_name,
        channel_type=channel_type,
        start_time=start_time,
        end_time=end_time,
        message_count=len(messages),
        participants=participants,
        messages_content=messages_content,
        permalink=permalink,
    )
class SlackConversationPage(Page):
    """A chunk of Slack conversation messages from a single channel.

    Carries the channel identity, the time range and size of the chunk, the
    participants' display names, and the combined message text.
    """

    # Identifier fields are declared with exclude=True.
    conversation_id: str = Field(
        description="Unique identifier for this conversation chunk", exclude=True
    )
    channel_id: str = Field(
        description="Channel ID where conversation occurred", exclude=True
    )
    channel_name: Optional[str] = Field(description="Channel name for display")
    channel_type: str = Field(
        description="Type: public_channel, private_channel, im, mpim"
    )
    start_time: datetime = Field(description="Start time of conversation chunk")
    end_time: datetime = Field(description="End time of conversation chunk")
    message_count: int = Field(description="Number of messages in this chunk")
    participants: List[str] = Field(
        description="List of user display names who participated"
    )
    messages_content: str = Field(
        description="Combined formatted content of all messages"
    )
    permalink: str = Field(description="Slack permalink to conversation")

    # NOTE(review): the three computed fields below are placeholders; as
    # written they always return None / [] and nothing in this class lets a
    # service populate them.

    @computed_field
    def next_conversation_uri(self) -> Optional[PageURI]:
        """URI to next conversation chunk if it exists (currently always None)."""
        # This will be populated by the service based on temporal ordering
        return None

    @computed_field
    def prev_conversation_uri(self) -> Optional[PageURI]:
        """URI to previous conversation chunk if it exists (currently always None)."""
        # This will be populated by the service based on temporal ordering
        return None

    @computed_field
    def related_threads(self) -> List[PageURI]:
        """URIs to related thread pages in this conversation (currently always empty)."""
        # This will be populated by the service when threads are detected
        return []
Field(description="Channel name for display") + parent_message: str = Field(description="Parent message that started the thread") + messages: List[SlackMessageSummary] = Field( + description="All messages in the thread" + ) + message_count: int = Field(description="Total number of messages in thread") + participants: List[str] = Field( + description="List of user display names who participated" + ) + created_at: datetime = Field(description="When the thread was created") + last_reply_at: Optional[datetime] = Field(description="When last reply was posted") + permalink: str = Field(description="Slack permalink to thread") + + @property + def thread_messages(self) -> str: + """Formatted string of all thread messages for content search.""" + formatted_messages = [] + for msg in self.messages: + formatted_messages.append(f"{msg.display_name}: {msg.text}") + return "\n".join(formatted_messages) + + +class SlackChannelPage(Page): + """A Slack channel with metadata and recent activity summary.""" + + channel_id: str = Field(description="Slack channel ID", exclude=True) + name: str = Field(description="Channel name") + channel_type: str = Field( + description="Type: public_channel, private_channel, im, mpim" + ) + topic: Optional[str] = Field(description="Channel topic") + purpose: Optional[str] = Field(description="Channel purpose") + member_count: int = Field(description="Number of members in channel") + created: datetime = Field(description="When channel was created") + is_archived: bool = Field(description="Whether channel is archived") + last_activity: Optional[datetime] = Field(description="Last message timestamp") + message_urls: List[str] = Field( + default=[], + description="List of recent message URLs in this channel", + exclude=True, + ) + permalink: str = Field(description="Slack permalink to channel") + + +class SlackUserPage(Page): + """A Slack user profile with information.""" + + user_id: str = Field(description="Slack user ID", exclude=True) + name: str = 
class SlackMessagePage(Page):
    """A single Slack message with context and links to related conversation."""

    # Identifier fields are excluded from the serialized page content.
    message_ts: str = Field(description="Message timestamp identifier", exclude=True)
    channel_id: str = Field(description="Channel ID where message exists", exclude=True)
    channel_name: Optional[str] = Field(description="Channel name for display")
    channel_type: str = Field(
        description="Type: public_channel, private_channel, im, mpim"
    )
    user_id: str = Field(description="User ID who sent the message", exclude=True)
    display_name: str = Field(description="Display name of message sender")
    text_content: str = Field(description="Message text content")
    timestamp: datetime = Field(description="When message was sent")
    thread_ts: Optional[str] = Field(
        description="Thread timestamp if message is part of a thread"
    )
    # Navigation links to the neighbouring messages; both default to None.
    next_message_uri: Optional[PageURI] = Field(
        default=None, description="URI to the next message in chronological order"
    )
    previous_message_uri: Optional[PageURI] = Field(
        default=None,
        description="URI to the previous message in reverse chronological order",
    )
    permalink: str = Field(description="Slack permalink to message")

    @computed_field
    def conversation_uri(self) -> Optional[PageURI]:
        """URI to the full conversation containing this message."""
        # Placeholder: always None here; the service is expected to fill this in.
        return None

    @computed_field
    def thread_uri(self) -> Optional[PageURI]:
        """URI to thread page if this message is part of a thread."""
        if not self.thread_ts:
            return None

        # Thread IDs join channel and timestamp with "_" because colons
        # aren't allowed in PageURI IDs.
        return PageURI(
            root=self.uri.root,
            type="slack_thread",
            id="_".join((self.channel_id, self.thread_ts)),
            version=1,
        )
def _register_handlers(self) -> None:
    """Register one page handler per Slack page type on ``self.context``.

    Each handler is a closure over ``self`` that receives the ID part of a
    page URI and delegates to the matching ``get_*_page`` method, which
    returns the page of that type.
    """

    # "slack_conversation" URIs -> conversation chunk pages.
    @self.context.handler("slack_conversation")
    def handle_conversation(conversation_id: str) -> SlackConversationPage:
        return self.get_conversation_page(conversation_id)

    # "slack_thread" URIs -> thread pages.
    @self.context.handler("slack_thread")
    def handle_thread(thread_id: str) -> SlackThreadPage:
        return self.get_thread_page(thread_id)

    # "slack_channel" URIs -> channel metadata pages.
    @self.context.handler("slack_channel")
    def handle_channel(channel_id: str) -> SlackChannelPage:
        return self.get_channel_page(channel_id)

    # "slack_user" URIs -> user profile pages.
    @self.context.handler("slack_user")
    def handle_user(user_id: str) -> SlackUserPage:
        return self.get_user_page(user_id)

    # "slack_channel_list" URIs -> workspace channel listing.
    @self.context.handler("slack_channel_list")
    def handle_channel_list(workspace_id: str) -> SlackChannelListPage:
        return self.get_channel_list_page(workspace_id)

    # "slack_message" URIs -> single-message pages; the only handler that logs.
    @self.context.handler("slack_message")
    def handle_message(message_id: str) -> SlackMessagePage:
        logger.info(f"Message handler called for message_id: {message_id}")
        return self.get_message_page(message_id)
def get_user_display_name(self, user_id: str) -> str:
    """Return the best available human-readable name for ``user_id``.

    Preference order is display name, then real name, then username. The
    user page is obtained through ``get_user_page``. An empty ``user_id``
    yields the literal ``"unknown"`` without any lookup.
    """
    if user_id:
        page = self.get_user_page(user_id)
        for candidate in (page.display_name, page.real_name):
            if candidate:
                return candidate
        # Nothing nicer is set: fall back to the username.
        return page.name

    return "unknown"
def get_message_page(self, message_id: str) -> SlackMessagePage:
    """Return the page for ``message_id``, building it when the cache misses.

    Args:
        message_id: Encoded message ID used as the ``slack_message`` URI id.

    Returns:
        The cached page if present, otherwise the result of
        ``create_message_page``.
    """
    logger.info(f"Getting message page for message_id: {message_id}")

    cache_key = PageURI(root=self.context.root, type="slack_message", id=message_id)
    cached = self.page_cache.get_page(SlackMessagePage, cache_key)

    if not cached:
        logger.info(f"No cached page found, creating new message page for {message_id}")
        # Cache miss: build the page on demand.
        return self.create_message_page(message_id)

    logger.info(f"Found cached message page for {message_id}")
    return cached
def create_message_page(self, message_id: str) -> SlackMessagePage:
    """Create, cache and return a message page from an encoded message ID.

    Args:
        message_id: Encoded ID in the form ``channel_id_message_ts``.

    Returns:
        The newly built ``SlackMessagePage``, including links to the
        previous and next messages when they can be found.

    Raises:
        ValueError: If ``message_id`` cannot be decoded.
        RuntimeError: If the message itself cannot be fetched.
    """
    logger.info(f"Creating message page for message ID: {message_id}")

    # Parse message ID and get channel info
    channel_id, message_ts = self.parser.decode_message_id(message_id)
    channel_page = self.get_channel_page(channel_id)

    # Fetch the target message. Slack's history endpoint returns the newest
    # messages of the requested window first, so the window must END at the
    # target (latest + inclusive) for limit=1 to select it. Anchoring with
    # `oldest` would instead yield the most recent message in the channel.
    messages, _ = self.api_client.get_conversation_history(
        channel_id=channel_id, latest=message_ts, inclusive=True, limit=1
    )
    # Guard against silently building a page from a different message.
    if not messages or messages[0].get("ts") != message_ts:
        raise RuntimeError(f"Unable to find message: {message_id}")
    target_message = messages[0]

    # Fetch adjacent messages
    prev_message = self._fetch_single_message(channel_id, message_ts, before=True)
    next_message = self._fetch_single_message(channel_id, message_ts, before=False)

    # Create navigation URIs
    prev_uri = (
        self._create_message_uri(channel_id, prev_message["ts"])
        if prev_message
        else None
    )
    next_uri = (
        self._create_message_uri(channel_id, next_message["ts"])
        if next_message
        else None
    )

    # Parse message metadata
    user_id = target_message.get("user", "")
    display_name = self.get_user_display_name(user_id)
    timestamp = datetime.fromtimestamp(float(target_message["ts"]), tz=timezone.utc)
    thread_ts = target_message.get("thread_ts")

    # Create page URI and permalink
    uri = self._create_message_uri(channel_id, message_ts)
    permalink = f"https://slack.com/app_redirect?channel={channel_id}&message_ts={message_ts}"

    page = SlackMessagePage(
        uri=uri,
        message_ts=message_ts,
        channel_id=channel_id,
        channel_name=channel_page.name,
        channel_type=channel_page.channel_type,
        user_id=user_id,
        display_name=display_name,
        text_content=target_message.get("text", ""),
        timestamp=timestamp,
        thread_ts=thread_ts,
        next_message_uri=next_uri,
        previous_message_uri=prev_uri,
        permalink=permalink,
    )

    # Store in cache
    self.page_cache.store_page(page)

    return page
+ raise ValueError(f"Thread {thread_id} contains no messages") + + # Get channel page for info + channel_page = self.get_channel_page(channel_id) + channel_name = channel_page.name + + # Parse messages into SlackMessageSummary objects + message_summaries = [] + participants = set() + + for msg in messages: + user_id = msg.get("user", "") + user_name = self.get_user_display_name(user_id) + participants.add(user_name) + + timestamp_str = msg.get("ts", "") + timestamp = datetime.fromtimestamp(float(timestamp_str), tz=timezone.utc) + + summary = SlackMessageSummary( + display_name=user_name, text=msg.get("text", ""), timestamp=timestamp + ) + message_summaries.append(summary) + + # Get parent message (first message) + parent_message = messages[0].get("text", "") if messages else "" + + # Get time info + created_at = ( + message_summaries[0].timestamp + if message_summaries + else datetime.now(tz=timezone.utc) + ) + last_reply_at = ( + message_summaries[-1].timestamp if len(message_summaries) > 1 else None + ) + + # Create permalink + permalink = f"https://slack.com/app_redirect?channel={channel_id}&message_ts={thread_ts}" + + # Create URI + uri = PageURI( + root=self.context.root, type="slack_thread", id=thread_id, version=1 + ) + + page = SlackThreadPage( + uri=uri, + thread_ts=thread_ts, + channel_id=channel_id, + channel_name=channel_name, + parent_message=parent_message, + messages=message_summaries, + message_count=len(message_summaries), + participants=list(participants), + created_at=created_at, + last_reply_at=last_reply_at, + permalink=permalink, + ) + + # Store in cache + self.page_cache.store_page(page) + return page + + def get_channel_page(self, channel_id: str) -> SlackChannelPage: + """Get channel page, creating if needed.""" + # Try to get from cache first + uri = PageURI(root=self.context.root, type="slack_channel", id=channel_id) + existing_page = self.page_cache.get_page(SlackChannelPage, uri) + + if existing_page: + return existing_page + + # Create 
def create_channel_page(self, channel_id: str) -> SlackChannelPage:
    """Fetch channel metadata, build a ``SlackChannelPage`` and cache it.

    Args:
        channel_id: Slack channel ID to look up through the API client.

    Returns:
        The newly created page, already stored in the page cache.
    """
    info = self.api_client.get_channel_info(channel_id)

    # Basic metadata; the channel ID doubles as the name when none is given.
    name = info.get("name", channel_id)
    channel_type = self.parser.determine_channel_type(info)
    topic = info.get("topic", {}).get("value")
    purpose = info.get("purpose", {}).get("value")

    # Member count is best-effort: a failed lookup is logged and reported as 0.
    try:
        member_count = len(self.api_client.get_channel_members(channel_id))
    except Exception as e:
        logger.warning(f"Failed to get member count for {channel_id}: {e}")
        member_count = 0

    # A missing or zero creation timestamp falls back to the current time.
    created_ts = info.get("created", 0)
    if created_ts:
        created = datetime.fromtimestamp(created_ts, tz=timezone.utc)
    else:
        created = datetime.now(tz=timezone.utc)

    channel_page = SlackChannelPage(
        uri=PageURI(
            root=self.context.root, type="slack_channel", id=channel_id, version=1
        ),
        channel_id=channel_id,
        name=name,
        channel_type=channel_type,
        topic=topic,
        purpose=purpose,
        member_count=member_count,
        created=created,
        is_archived=info.get("is_archived", False),
        last_activity=None,  # TODO: Implement if needed
        message_urls=[],  # Empty by default, can be populated on demand
        permalink=f"https://slack.com/app_redirect?channel={channel_id}",
    )

    self.page_cache.store_page(channel_page)
    return channel_page
def search_messages(
    self, query: str, page_token: Optional[str] = None, page_size: int = 20
) -> Tuple[List[PageURI], Optional[str]]:
    """Search Slack messages and return page URIs with pagination.

    The first hit of each thread is returned as a ``slack_thread`` URI.
    Every other hit, including later hits from an already-listed thread,
    is returned as a ``slack_message`` URI.

    Args:
        query: Slack search query string.
        page_token: Page number as a string; missing or non-numeric
            tokens fall back to page 1.
        page_size: Number of results requested per page.

    Returns:
        ``(uris, next_token)``. ``next_token`` is ``None`` on the last
        page. Any failure is logged and reported as ``([], None)``.
    """
    try:
        # Use Slack search API (page-number based pagination)
        page_num = 1
        if page_token:
            try:
                page_num = int(page_token)
            except ValueError:
                page_num = 1

        messages, pagination = self.api_client.search_messages(
            query=query, count=page_size, page=page_num
        )

        uris: List[PageURI] = []
        processed_threads = set()

        logger.info(f"Processing {len(messages)} search result messages")

        for i, msg in enumerate(messages, start=1):
            channel_id = (msg.get("channel") or {}).get("id", "")
            ts = msg.get("ts", "")
            thread_ts = msg.get("thread_ts")
            text = msg.get("text", "")
            user = msg.get("user", "")

            logger.info(f"Message {i}: channel_id={channel_id}, ts={ts}, user={user}")
            # Message bodies can be sensitive; keep them out of INFO logs.
            logger.debug(
                f"Message {i} text: {text[:100]}{'...' if len(text) > 100 else ''}"
            )

            if not channel_id:
                # Without a channel no resolvable ID can be built.
                logger.warning(f"Skipping search result {i}: missing channel id")
                continue

            # Build the key once with the parser so the de-duplication check
            # and the stored value can never drift apart.
            thread_id = (
                self.parser.encode_thread_id(channel_id, thread_ts)
                if thread_ts
                else None
            )

            if thread_id and thread_id not in processed_threads:
                # First hit in this thread - return the thread page
                logger.info(f"Creating thread URI for {thread_id}")
                uris.append(
                    PageURI(
                        root=self.context.root, type="slack_thread", id=thread_id
                    )
                )
                processed_threads.add(thread_id)
            elif ts:
                # Regular message or further thread message - return message page
                message_id = self.parser.encode_message_id(channel_id, ts)
                logger.info(
                    f"Creating message URI for {message_id} (channel={channel_id}, ts={ts})"
                )
                uris.append(
                    PageURI(
                        root=self.context.root,
                        type="slack_message",
                        id=message_id,
                    )
                )
            else:
                logger.warning(f"Skipping search result {i}: missing timestamp")

        logger.info(f"Generated {len(uris)} URIs from search results")

        # Determine next page token. Read "page" once with a default so a
        # partial pagination dict cannot raise and discard the results.
        next_token = None
        current_page = pagination.get("page", 1)
        if current_page < pagination.get("page_count", 1):
            next_token = str(current_page + 1)

        return uris, next_token

    except Exception as e:
        # Best-effort boundary: callers get an empty page rather than an
        # error, but the traceback is preserved in the log.
        logger.exception(f"Error searching conversations: {e}")
        return [], None
def _search_messages(
    self,
    query: str,
    cursor: Optional[str] = None,
    page_size: int = 10,
) -> PaginatedResponse[SlackMessagePage]:
    """Run a message search and resolve the hits into message pages.

    Args:
        query: Slack search query string.
        cursor: Pagination cursor, forwarded as the service's page token.
        page_size: Number of results requested per page.

    Returns:
        The resolved ``SlackMessagePage`` objects and the cursor for the
        next page, if any.
    """
    # The cursor is forwarded by keyword as the service's page token.
    uris, next_page_token = self.slack_service.search_messages(
        query, page_token=cursor, page_size=page_size
    )

    pages = []

    for uri in uris:
        page = self.context.get_page(uri)
        if isinstance(page, SlackMessagePage):
            pages.append(page)
        else:
            # SlackService.search_messages deliberately returns slack_thread
            # URIs for threaded hits. Raising here would let one such hit
            # abort the whole search, so other page types are skipped to
            # keep the declared result type.
            logger.warning(
                f"Skipping search result {uri}: resolved page is not "
                f"SlackMessagePage: {type(page)}"
            )

    return PaginatedResponse(
        results=pages,
        next_cursor=next_page_token,
    )
@tool()
def search_messages_by_date_range(
    self,
    start_date: str,
    num_days: int,
    cursor: Optional[str] = None,
) -> PaginatedResponse[SlackMessagePage]:
    """Search conversations within a date range.

    Args:
        start_date: Start date in YYYY-MM-DD format
        num_days: Length of the range in days, counted from start_date
        cursor: Pagination cursor for next page

    Raises:
        ValueError: If start_date is not a valid YYYY-MM-DD date
    """
    # Parse once, then render both bounds from the parsed values so the
    # query always carries zero-padded dates.
    start = datetime.strptime(start_date, "%Y-%m-%d")
    end = start + timedelta(days=num_days)
    # NOTE(review): Slack's after:/before: modifiers look exclusive of the
    # named day - confirm whether start_date itself should be included.
    query = f"after:{start.strftime('%Y-%m-%d')} before:{end.strftime('%Y-%m-%d')}"

    return self._search_messages(query, cursor)
+ + Args: + person: Optional person identifier (@username or user ID) for DMs (if None, returns all DMs) + cursor: Pagination cursor for next page + """ + if person: + validated_person = self.slack_service.parser.validate_person_identifier( + person + ) + query = f"in:{validated_person}" + else: + query = "in:@" # All DMs + + response = self._search_messages(query, cursor) + logger.info(f"Direct messages response: {response}") + print(response) + return response + + @tool() + def get_conversation_with_person( + self, person: str, cursor: Optional[str] = None + ) -> PaginatedResponse[SlackMessagePage]: + """Get conversations involving a specific person. + + Args: + person: Person identifier (@username or user ID) + cursor: Pagination cursor for next page + """ + # Validate and format person identifier + validated_person = self.slack_service.parser.validate_person_identifier(person) + query = validated_person # Already formatted for search + + return self._search_messages(query, cursor) + + @property + def name(self) -> str: + return "slack" diff --git a/src/pragweb/slack/utils.py b/src/pragweb/slack/utils.py new file mode 100644 index 0000000..eea8cf1 --- /dev/null +++ b/src/pragweb/slack/utils.py @@ -0,0 +1,186 @@ +"""Slack utility classes for parsing and formatting data.""" + +import logging +from datetime import datetime +from typing import Any, Callable, Dict, List + +logger = logging.getLogger(__name__) + + +class SlackParser: + """Parser for Slack data that handles message formatting, ID encoding, and content extraction.""" + + @staticmethod + def encode_message_id(channel_id: str, message_ts: str) -> str: + """Encode channel ID and message timestamp into a URI-safe message ID. 
+ + Args: + channel_id: Slack channel ID + message_ts: Message timestamp + + Returns: + Encoded message ID safe for PageURI + """ + # Use underscore as separator since colons aren't allowed in PageURI IDs + return f"{channel_id}_{message_ts}" + + @staticmethod + def decode_message_id(message_id: str) -> tuple[str, str]: + """Decode message ID back to channel ID and message timestamp. + + Args: + message_id: Encoded message ID + + Returns: + Tuple of (channel_id, message_ts) + + Raises: + ValueError: If message ID format is invalid + """ + try: + # Split on last underscore to handle cases where channel_id might have underscores + parts = message_id.rsplit("_", 1) + if len(parts) != 2: + raise ValueError("Invalid format") + return parts[0], parts[1] + except ValueError: + raise ValueError( + f"Invalid message ID format: {message_id}. Expected 'channel_id_message_ts'" + ) + + @staticmethod + def encode_thread_id(channel_id: str, thread_ts: str) -> str: + """Encode channel ID and thread timestamp into a URI-safe thread ID. + + Args: + channel_id: Slack channel ID + thread_ts: Thread timestamp + + Returns: + Encoded thread ID safe for PageURI + """ + # Use underscore as separator since colons aren't allowed in PageURI IDs + return f"{channel_id}_{thread_ts}" + + @staticmethod + def decode_thread_id(thread_id: str) -> tuple[str, str]: + """Decode thread ID back to channel ID and thread timestamp. + + Args: + thread_id: Encoded thread ID + + Returns: + Tuple of (channel_id, thread_ts) + + Raises: + ValueError: If thread ID format is invalid + """ + try: + # Split on last underscore to handle cases where channel_id might have underscores + parts = thread_id.rsplit("_", 1) + if len(parts) != 2: + raise ValueError("Invalid format") + return parts[0], parts[1] + except ValueError: + raise ValueError( + f"Invalid thread ID format: {thread_id}. 
Expected 'channel_id_thread_ts'" + ) + + @staticmethod + def determine_channel_type(channel_info: Dict[str, Any]) -> str: + """Determine channel type string.""" + if channel_info.get("is_channel"): + return "public_channel" + elif channel_info.get("is_group"): + return "private_channel" + elif channel_info.get("is_im"): + return "im" + elif channel_info.get("is_mpim"): + return "mpim" + else: + return "unknown" + + @staticmethod + def get_user_display_name(user_info: Dict[str, Any]) -> str: + """Get user display name for UI.""" + # Prefer display name, then real name, then username + display_name = user_info.get("profile", {}).get("display_name") + if display_name: + return str(display_name) + + real_name = user_info.get("real_name") + if real_name: + return str(real_name) + + name = user_info.get("name", user_info.get("id", "unknown")) + return str(name) + + @staticmethod + def format_messages_for_content( + messages: List[Dict[str, Any]], get_user_display_name_fn: Callable[[str], str] + ) -> str: + """Format messages into readable content. + + Args: + messages: List of Slack message objects + get_user_display_name_fn: Function to get display name for user ID + """ + formatted_messages = [] + + for msg in messages: + user_id = msg.get("user", "unknown") + user_name = get_user_display_name_fn(user_id) + text = msg.get("text", "") + + # Format timestamp + timestamp = msg.get("ts", "") + if timestamp: + try: + dt = datetime.fromtimestamp(float(timestamp)) + time_str = dt.strftime("%Y-%m-%d %H:%M") + except (ValueError, TypeError): + time_str = timestamp + else: + time_str = "unknown" + + formatted_messages.append(f"[{time_str}] {user_name}: {text}") + + return "\n".join(formatted_messages) + + @staticmethod + def validate_person_identifier(person: str) -> str: + """Validate and format person identifier for Slack search. 
@staticmethod
def validate_person_identifier(person: str) -> str:
    """Validate and format a person identifier for Slack search.

    Two input shapes are accepted (surrounding whitespace is ignored):

    * ``@username`` - returned unchanged; the part after ``@`` may only
      contain alphanumeric characters, dots, dashes and underscores.
    * A Slack user ID - returned wrapped as ``<@ID>``. An ID starts with
      ``U`` or ``W`` followed by at least eight uppercase ASCII letters or
      digits; its length is deliberately not pinned to a single value.

    Args:
        person: Person identifier (should start with @ or be a user ID).

    Returns:
        Formatted person identifier for search.

    Raises:
        ValueError: If the identifier is empty or matches neither shape.
    """
    if not person or not person.strip():
        raise ValueError("Person identifier cannot be empty")

    person = person.strip()

    # Handle @username format
    if person.startswith("@"):
        username = person[1:]
        if not username:  # Just "@"
            raise ValueError("Username cannot be empty after @")
        if not all(c.isalnum() or c in "._-" for c in username):
            raise ValueError(f"Invalid username format: {person}")
        return person

    # Handle Slack user ID format. Requiring uppercase/digits keeps ordinary
    # words that happen to start with U or W from being taken for IDs.
    id_body = person[1:]
    looks_like_user_id = (
        person[0] in "UW"
        and len(id_body) >= 8
        and id_body.isascii()
        and id_body.isalnum()
        and id_body == id_body.upper()
    )
    # Previously accepted shape (U + exactly 10 alphanumerics of any case),
    # kept so that every formerly valid input remains valid.
    legacy_user_id = (
        person.startswith("U") and len(person) == 11 and id_body.isalnum()
    )
    if looks_like_user_id or legacy_user_id:
        return f"<@{person}>"

    # Reject everything else
    raise ValueError(
        f"Person identifier '{person}' is invalid. "
        f"Use @username (e.g., '@john.doe') or user ID (e.g., 'U1234567890') format."
    )
+ ) diff --git a/tests/core/test_page_cache.py b/tests/core/test_page_cache.py index 4bd06c2..4aa68f3 100644 --- a/tests/core/test_page_cache.py +++ b/tests/core/test_page_cache.py @@ -5,10 +5,11 @@ """ import tempfile -from datetime import datetime -from typing import Any, Optional +from datetime import datetime, timedelta, timezone +from typing import Any, List, Optional import pytest +from pydantic import BaseModel, Field, ValidationError from praga_core.page_cache import PageCache from praga_core.types import Page, PageURI @@ -704,29 +705,25 @@ def test_different_uris_different_records(self, page_cache: PageCache) -> None: class TestPageURISerialization: """Test PageURI serialization and deserialization in page_cache.""" - def test_convert_page_uris_for_storage_single_uri( - self, page_cache: PageCache - ) -> None: + def test_serialize_for_storage_single_uri(self, page_cache: PageCache) -> None: """Test converting a single PageURI to string for storage.""" uri = PageURI(root="test", type="doc", id="123") - result = page_cache._convert_page_uris_for_storage(uri) + result = page_cache._serialize_for_storage(uri) assert result == str(uri) assert isinstance(result, str) - def test_convert_page_uris_for_storage_list_of_uris( - self, page_cache: PageCache - ) -> None: + def test_serialize_for_storage_list_of_uris(self, page_cache: PageCache) -> None: """Test converting a list of PageURIs to strings for storage.""" uris = [ PageURI(root="test", type="doc", id="123"), PageURI(root="test", type="doc", id="456"), ] - result = page_cache._convert_page_uris_for_storage(uris) + result = page_cache._serialize_for_storage(uris) expected = [str(uri) for uri in uris] assert result == expected assert all(isinstance(item, str) for item in result) - def test_convert_page_uris_for_storage_nested_structure( + def test_serialize_for_storage_nested_structure( self, page_cache: PageCache ) -> None: """Test converting nested structures containing PageURIs.""" @@ -740,16 +737,14 @@ def 
test_convert_page_uris_for_storage_nested_structure( "number": 42, } - result = page_cache._convert_page_uris_for_storage(nested_data) + result = page_cache._serialize_for_storage(nested_data) assert result["single_uri"] == str(uri1) assert result["uri_list"] == [str(uri1), str(uri2)] assert result["regular_data"] == "some string" assert result["number"] == 42 - def test_convert_page_uris_for_storage_non_uri_values( - self, page_cache: PageCache - ) -> None: + def test_serialize_for_storage_non_uri_values(self, page_cache: PageCache) -> None: """Test that non-PageURI values are returned unchanged.""" test_values = [ "string", @@ -762,26 +757,22 @@ def test_convert_page_uris_for_storage_non_uri_values( ] for value in test_values: - result = page_cache._convert_page_uris_for_storage(value) + result = page_cache._serialize_for_storage(value) assert result == value - def test_convert_page_uris_from_storage_single_uri( - self, page_cache: PageCache - ) -> None: + def test_deserialize_from_storage_single_uri(self, page_cache: PageCache) -> None: """Test converting a string back to PageURI from storage.""" from praga_core.types import PageURI uri_string = "test/doc:123@1" - result = page_cache._convert_page_uris_from_storage(uri_string, PageURI) + result = page_cache._deserialize_from_storage(uri_string, PageURI) assert isinstance(result, PageURI) assert result.root == "test" assert result.type == "doc" assert result.id == "123" - def test_convert_page_uris_from_storage_optional_uri( - self, page_cache: PageCache - ) -> None: + def test_deserialize_from_storage_optional_uri(self, page_cache: PageCache) -> None: """Test converting Optional[PageURI] from storage.""" from typing import Optional @@ -789,27 +780,21 @@ def test_convert_page_uris_from_storage_optional_uri( # Test with actual URI string uri_string = "test/doc:123@1" - result = page_cache._convert_page_uris_from_storage( - uri_string, Optional[PageURI] - ) + result = 
class TestPydanticModelSerialization:
    """Test Pydantic model serialization and deserialization in PageCache.

    Covers the private ``_serialize_for_storage`` / ``_deserialize_from_storage``
    helpers directly, plus store/get round-trips of pages whose fields hold
    Pydantic models (single, optional, and lists of models).
    """

    class SlackMessageSummary(BaseModel):
        """Test Pydantic model with datetime fields (like SlackMessageSummary)."""

        display_name: str = Field(description="User who sent the message")
        text: str = Field(description="Message text content")
        timestamp: datetime = Field(description="Message timestamp")

    class SlackThreadPage(Page):
        """Test page with list of Pydantic models."""

        thread_ts: str = Field(description="Thread timestamp", exclude=True)
        channel_name: str = Field(description="Channel name")
        messages: List["TestPydanticModelSerialization.SlackMessageSummary"] = Field(
            description="Messages in thread"
        )
        message_count: int = Field(description="Number of messages")
        participants: List[str] = Field(description="Participants")

        def __init__(self, **data: Any) -> None:
            super().__init__(**data)
            # Rough token estimate derived from the channel name length.
            self._metadata.token_count = len(data.get("channel_name", "")) // 4

    class BlogPost(Page):
        """Test page with optional Pydantic model field."""

        title: str = Field(description="Post title")
        content: str = Field(description="Post content")
        author: Optional["TestPydanticModelSerialization.SlackMessageSummary"] = Field(
            None, description="Author info"
        )

        def __init__(self, **data: Any) -> None:
            super().__init__(**data)
            # Rough token estimate derived from the content length.
            self._metadata.token_count = len(data.get("content", "")) // 4

    def test_serialize_for_storage_pydantic_model(self, page_cache: PageCache) -> None:
        """Test serializing a Pydantic model for storage."""
        message = self.SlackMessageSummary(
            display_name="Alice",
            text="Hello world",
            timestamp=datetime(2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc),
        )

        result = page_cache._serialize_for_storage(message)

        assert isinstance(result, dict)
        assert result["display_name"] == "Alice"
        assert result["text"] == "Hello world"
        # datetime should be serialized as ISO string in JSON mode
        assert result["timestamp"] == "2023-06-15T10:30:00Z"

    def test_serialize_for_storage_list_of_pydantic_models(
        self, page_cache: PageCache
    ) -> None:
        """Test serializing a list of Pydantic models for storage."""
        messages = [
            self.SlackMessageSummary(
                display_name="Alice",
                text="First message",
                timestamp=datetime(2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc),
            ),
            self.SlackMessageSummary(
                display_name="Bob",
                text="Second message",
                timestamp=datetime(2023, 6, 15, 10, 31, 0, tzinfo=timezone.utc),
            ),
        ]

        result = page_cache._serialize_for_storage(messages)

        assert isinstance(result, list)
        assert len(result) == 2
        assert all(isinstance(item, dict) for item in result)
        assert result[0]["display_name"] == "Alice"
        assert result[1]["display_name"] == "Bob"
        assert result[0]["timestamp"] == "2023-06-15T10:30:00Z"
        assert result[1]["timestamp"] == "2023-06-15T10:31:00Z"

    def test_deserialize_from_storage_pydantic_model(
        self, page_cache: PageCache
    ) -> None:
        """Test deserializing a Pydantic model from storage."""
        stored_data = {
            "display_name": "Alice",
            "text": "Hello world",
            "timestamp": "2023-06-15T10:30:00Z",
        }

        result = page_cache._deserialize_from_storage(
            stored_data, self.SlackMessageSummary
        )

        assert isinstance(result, self.SlackMessageSummary)
        assert result.display_name == "Alice"
        assert result.text == "Hello world"
        assert result.timestamp == datetime(2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc)

    def test_deserialize_from_storage_list_of_pydantic_models(
        self, page_cache: PageCache
    ) -> None:
        """Test deserializing a list of Pydantic models from storage."""
        stored_data = [
            {
                "display_name": "Alice",
                "text": "First message",
                "timestamp": "2023-06-15T10:30:00Z",
            },
            {
                "display_name": "Bob",
                "text": "Second message",
                "timestamp": "2023-06-15T10:31:00Z",
            },
        ]

        result = page_cache._deserialize_from_storage(
            stored_data, List[self.SlackMessageSummary]
        )

        assert isinstance(result, list)
        assert len(result) == 2
        assert all(isinstance(item, self.SlackMessageSummary) for item in result)
        assert result[0].display_name == "Alice"
        assert result[1].display_name == "Bob"
        assert result[0].timestamp == datetime(
            2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc
        )
        assert result[1].timestamp == datetime(
            2023, 6, 15, 10, 31, 0, tzinfo=timezone.utc
        )

    def test_deserialize_from_storage_optional_pydantic_model(
        self, page_cache: PageCache
    ) -> None:
        """Test deserializing Optional[PydanticModel] from storage."""
        # Test with actual data
        stored_data = {
            "display_name": "Alice",
            "text": "Hello",
            "timestamp": "2023-06-15T10:30:00Z",
        }

        result = page_cache._deserialize_from_storage(
            stored_data, Optional[self.SlackMessageSummary]
        )

        assert isinstance(result, self.SlackMessageSummary)
        assert result.display_name == "Alice"

        # Test with None
        result_none = page_cache._deserialize_from_storage(
            None, Optional[self.SlackMessageSummary]
        )
        assert result_none is None

    def test_is_pydantic_model_type(self, page_cache: PageCache) -> None:
        """Test the _is_pydantic_model_type helper method."""
        # Should return True for Pydantic models
        assert page_cache._is_pydantic_model_type(self.SlackMessageSummary) is True
        assert page_cache._is_pydantic_model_type(BaseModel) is True

        # Should return False for non-Pydantic types
        assert page_cache._is_pydantic_model_type(str) is False
        assert page_cache._is_pydantic_model_type(int) is False
        assert page_cache._is_pydantic_model_type(dict) is False
        assert page_cache._is_pydantic_model_type(list) is False
        assert page_cache._is_pydantic_model_type(None) is False

    def test_store_and_retrieve_page_with_pydantic_models(
        self, page_cache: PageCache
    ) -> None:
        """Test end-to-end storage and retrieval of page with Pydantic model fields."""
        messages = [
            self.SlackMessageSummary(
                display_name="Alice",
                text="Thread starter message",
                timestamp=datetime(2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc),
            ),
            self.SlackMessageSummary(
                display_name="Bob",
                text="Reply to thread",
                timestamp=datetime(2023, 6, 15, 10, 31, 0, tzinfo=timezone.utc),
            ),
            self.SlackMessageSummary(
                display_name="Charlie",
                text="Another reply",
                timestamp=datetime(2023, 6, 15, 10, 32, 0, tzinfo=timezone.utc),
            ),
        ]

        thread_page = self.SlackThreadPage(
            uri=PageURI(root="test", type="slack_thread", id="thread123"),
            thread_ts="1687009800.001",
            channel_name="general",
            messages=messages,
            message_count=3,
            participants=["Alice", "Bob", "Charlie"],
        )

        # Store the page
        result = page_cache.store_page(thread_page)
        assert result is True

        # Retrieve the page
        retrieved_page = page_cache.get_page(self.SlackThreadPage, thread_page.uri)

        assert retrieved_page is not None
        assert retrieved_page.channel_name == "general"
        assert retrieved_page.message_count == 3
        assert retrieved_page.participants == ["Alice", "Bob", "Charlie"]

        # Verify Pydantic models are properly deserialized
        assert isinstance(retrieved_page.messages, list)
        assert len(retrieved_page.messages) == 3
        assert all(
            isinstance(msg, self.SlackMessageSummary) for msg in retrieved_page.messages
        )

        # Check first message
        msg1 = retrieved_page.messages[0]
        assert msg1.display_name == "Alice"
        assert msg1.text == "Thread starter message"
        assert msg1.timestamp == datetime(2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc)

        # Check second message
        msg2 = retrieved_page.messages[1]
        assert msg2.display_name == "Bob"
        assert msg2.text == "Reply to thread"
        assert msg2.timestamp == datetime(2023, 6, 15, 10, 31, 0, tzinfo=timezone.utc)

    def test_update_page_with_pydantic_models(self, page_cache: PageCache) -> None:
        """Test updating a page with modified Pydantic model fields."""
        initial_message = self.SlackMessageSummary(
            display_name="Alice",
            text="Initial message",
            timestamp=datetime(2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc),
        )

        thread_page = self.SlackThreadPage(
            uri=PageURI(root="test", type="slack_thread", id="thread456"),
            thread_ts="1687009800.002",
            channel_name="general",
            messages=[initial_message],
            message_count=1,
            participants=["Alice"],
        )

        # Store initial page
        page_cache.store_page(thread_page)

        # Update with additional messages
        updated_messages = [
            initial_message,
            self.SlackMessageSummary(
                display_name="Bob",
                text="Added reply",
                timestamp=datetime(2023, 6, 15, 10, 31, 0, tzinfo=timezone.utc),
            ),
        ]

        updated_page = self.SlackThreadPage(
            uri=thread_page.uri,  # Same URI
            thread_ts="1687009800.002",
            channel_name="general",
            messages=updated_messages,
            message_count=2,
            participants=["Alice", "Bob"],
        )

        # Update the page
        result = page_cache.store_page(updated_page)
        assert result is False  # Should be update, not insert

        # Retrieve and verify
        retrieved_page = page_cache.get_page(self.SlackThreadPage, thread_page.uri)

        assert retrieved_page is not None
        assert retrieved_page.message_count == 2
        assert len(retrieved_page.messages) == 2
        assert retrieved_page.participants == ["Alice", "Bob"]

        # Verify both messages are present and correctly deserialized
        assert retrieved_page.messages[0].display_name == "Alice"
        assert retrieved_page.messages[0].text == "Initial message"
        assert retrieved_page.messages[1].display_name == "Bob"
        assert retrieved_page.messages[1].text == "Added reply"

    def test_find_pages_by_pydantic_model_content(self, page_cache: PageCache) -> None:
        """Test attribute lookup of pages that carry Pydantic model fields.

        The filter itself is on the direct ``channel_name`` field, not on the
        nested message content; the nested models only need to survive the
        store/find round-trip without breaking the lookup.
        """
        # Create pages with different message contents
        thread1 = self.SlackThreadPage(
            uri=PageURI(root="test", type="slack_thread", id="thread1"),
            thread_ts="1687009800.001",
            channel_name="general",
            messages=[
                self.SlackMessageSummary(
                    display_name="Alice",
                    text="urgent task needs attention",
                    timestamp=datetime(2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc),
                )
            ],
            message_count=1,
            participants=["Alice"],
        )

        thread2 = self.SlackThreadPage(
            uri=PageURI(root="test", type="slack_thread", id="thread2"),
            thread_ts="1687009800.002",
            channel_name="general",
            messages=[
                self.SlackMessageSummary(
                    display_name="Bob",
                    text="casual conversation about lunch",
                    timestamp=datetime(2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc),
                )
            ],
            message_count=1,
            participants=["Bob"],
        )

        # Store both pages
        page_cache.store_page(thread1)
        page_cache.store_page(thread2)

        # Find pages by channel name (direct field)
        results = page_cache.find_pages_by_attribute(
            self.SlackThreadPage, lambda t: t.channel_name == "general"
        )

        assert len(results) == 2
        thread_ids = {page.uri.id for page in results}
        assert thread_ids == {"thread1", "thread2"}

    def test_optional_pydantic_model_field(self, page_cache: PageCache) -> None:
        """Test storing and retrieving pages with optional Pydantic model fields."""
        # Test with None author
        post_without_author = self.BlogPost(
            uri=PageURI(root="test", type="blog_post", id="post1"),
            title="Anonymous Post",
            content="This post has no author information",
            author=None,
        )

        # Test with author
        author_info = self.SlackMessageSummary(
            display_name="Alice",
            text="Author bio",
            timestamp=datetime(2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc),
        )

        post_with_author = self.BlogPost(
            uri=PageURI(root="test", type="blog_post", id="post2"),
            title="Authored Post",
            content="This post has author information",
            author=author_info,
        )

        # Store both posts
        page_cache.store_page(post_without_author)
        page_cache.store_page(post_with_author)

        # Retrieve post without author
        retrieved_post1 = page_cache.get_page(self.BlogPost, post_without_author.uri)
        assert retrieved_post1 is not None
        assert retrieved_post1.title == "Anonymous Post"
        assert retrieved_post1.author is None

        # Retrieve post with author
        retrieved_post2 = page_cache.get_page(self.BlogPost, post_with_author.uri)
        assert retrieved_post2 is not None
        assert retrieved_post2.title == "Authored Post"
        assert isinstance(retrieved_post2.author, self.SlackMessageSummary)
        assert retrieved_post2.author.display_name == "Alice"
        assert retrieved_post2.author.text == "Author bio"
        assert retrieved_post2.author.timestamp == datetime(
            2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc
        )

    def test_nested_pydantic_models(self, page_cache: PageCache) -> None:
        """Test serializing nested dict/list structures that contain Pydantic models."""
        nested_data = {
            "thread_info": {
                "messages": [
                    self.SlackMessageSummary(
                        display_name="Alice",
                        text="Nested message",
                        timestamp=datetime(2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc),
                    )
                ],
                "metadata": {"channel": "general", "priority": "high"},
            },
            "regular_field": "simple value",
        }

        # Test serialization
        serialized = page_cache._serialize_for_storage(nested_data)

        assert isinstance(serialized, dict)
        assert serialized["regular_field"] == "simple value"
        assert isinstance(serialized["thread_info"], dict)
        assert isinstance(serialized["thread_info"]["messages"], list)
        assert isinstance(serialized["thread_info"]["messages"][0], dict)
        assert serialized["thread_info"]["messages"][0]["display_name"] == "Alice"
        assert (
            serialized["thread_info"]["messages"][0]["timestamp"]
            == "2023-06-15T10:30:00Z"
        )

    def test_datetime_timezone_handling(self, page_cache: PageCache) -> None:
        """Test proper timezone handling in Pydantic model serialization."""
        # Test with UTC timezone
        utc_message = self.SlackMessageSummary(
            display_name="Alice",
            text="UTC message",
            timestamp=datetime(2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc),
        )

        # Test with different timezone
        est_tz = timezone(timedelta(hours=-5))
        est_message = self.SlackMessageSummary(
            display_name="Bob",
            text="EST message",
            timestamp=datetime(2023, 6, 15, 5, 30, 0, tzinfo=est_tz),
        )

        # Serialize both
        utc_serialized = page_cache._serialize_for_storage(utc_message)
        est_serialized = page_cache._serialize_for_storage(est_message)

        assert utc_serialized["timestamp"] == "2023-06-15T10:30:00Z"
        assert est_serialized["timestamp"] == "2023-06-15T05:30:00-05:00"

        # Deserialize both
        utc_deserialized = page_cache._deserialize_from_storage(
            utc_serialized, self.SlackMessageSummary
        )
        est_deserialized = page_cache._deserialize_from_storage(
            est_serialized, self.SlackMessageSummary
        )

        assert utc_deserialized.timestamp == datetime(
            2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc
        )
        assert est_deserialized.timestamp == datetime(
            2023, 6, 15, 5, 30, 0, tzinfo=est_tz
        )
class TestPydanticModelErrorScenarios:
    """Failure and pass-through paths of PageCache's Pydantic deserialization."""

    class InvalidMessage(BaseModel):
        """Model whose ``score`` constraint lets validation be made to fail."""

        display_name: str
        # Must be between 0 and 100
        score: int = Field(ge=0, le=100)

    def test_invalid_pydantic_model_data(self, page_cache: PageCache) -> None:
        """A value outside the declared range is rejected with ValidationError."""
        # 150 exceeds the model's upper bound of 100.
        out_of_range = dict(display_name="Alice", score=150)

        with pytest.raises(ValidationError):
            page_cache._deserialize_from_storage(out_of_range, self.InvalidMessage)

    def test_missing_required_fields(self, page_cache: PageCache) -> None:
        """Omitting the required ``score`` field is rejected with ValidationError."""
        without_score = dict(display_name="Alice")

        with pytest.raises(ValidationError):
            page_cache._deserialize_from_storage(without_score, self.InvalidMessage)

    def test_wrong_data_type_for_pydantic_model(self, page_cache: PageCache) -> None:
        """A non-dict value targeted at a model type is handed back unchanged."""
        raw_value = "not_a_dict"

        outcome = page_cache._deserialize_from_storage(raw_value, self.InvalidMessage)

        # Only dict payloads are turned into models; a string passes through.
        assert outcome == "not_a_dict"

    def test_non_pydantic_type_with_dict_data(self, page_cache: PageCache) -> None:
        """A dict targeted at a non-model type is handed back unchanged."""
        payload = {"key": "value"}

        # str is not a Pydantic model type, so no model construction is attempted.
        outcome = page_cache._deserialize_from_storage(payload, str)

        assert outcome == payload
test_non_pydantic_type_with_dict_data(self, page_cache: PageCache) -> None: + """Test that non-Pydantic types don't get processed as Pydantic models.""" + dict_data = {"key": "value"} + + # Should return unchanged since str is not a Pydantic model type + result = page_cache._deserialize_from_storage(dict_data, str) + assert result == dict_data diff --git a/tests/services/test_slack_api_client.py b/tests/services/test_slack_api_client.py new file mode 100644 index 0000000..88a973d --- /dev/null +++ b/tests/services/test_slack_api_client.py @@ -0,0 +1,763 @@ +"""Tests for SlackAPIClient.""" + +from unittest.mock import Mock + +import pytest +from slack_sdk.errors import SlackApiError +from slack_sdk.web.slack_response import SlackResponse + +from pragweb.slack.client import SlackAPIClient + + +class TestSlackAPIClient: + """Test suite for SlackAPIClient.""" + + def setup_method(self): + """Set up test environment.""" + # Create mock auth manager + self.mock_auth_manager = Mock() + self.mock_web_client = Mock() + self.mock_auth_manager.get_client.return_value = self.mock_web_client + + self.client = SlackAPIClient(self.mock_auth_manager) + + def test_init(self): + """Test SlackAPIClient initialization.""" + # Test with custom auth manager + assert self.client.auth_manager is self.mock_auth_manager + + # Test with default auth manager + default_client = SlackAPIClient() + assert default_client.auth_manager is not None + + def test_client_property(self): + """Test client property lazy loading.""" + # First access should call get_client + client = self.client.client + assert client is self.mock_web_client + self.mock_auth_manager.get_client.assert_called_once() + + # Second access should use cached client + client2 = self.client.client + assert client2 is self.mock_web_client + # Should still only be called once + self.mock_auth_manager.get_client.assert_called_once() + + def test_get_channel_info_success(self): + """Test successful channel info retrieval.""" + # Mock 
successful response + mock_channel_data = { + "id": "C1234567890", + "name": "general", + "is_channel": True, + "is_group": False, + "is_im": False, + "is_mpim": False, + "topic": {"value": "General discussion"}, + "purpose": {"value": "Company-wide announcements"}, + "created": 1234567890, + "is_archived": False, + } + + mock_response = Mock(spec=SlackResponse) + mock_response.data = {"ok": True, "channel": mock_channel_data} + self.mock_web_client.conversations_info.return_value = mock_response + + # Call method + result = self.client.get_channel_info("C1234567890") + + # Verify API call + self.mock_web_client.conversations_info.assert_called_once_with( + channel="C1234567890" + ) + + # Verify result + assert result == mock_channel_data + + def test_get_channel_info_error(self): + """Test channel info retrieval with error.""" + # Mock error response + mock_response = Mock(spec=SlackResponse) + mock_response.data = {"ok": False, "error": "channel_not_found"} + self.mock_web_client.conversations_info.return_value = mock_response + + # Should raise SlackApiError + with pytest.raises(SlackApiError): + self.client.get_channel_info("C1234567890") + + def test_get_channel_info_invalid_data(self): + """Test channel info retrieval with invalid channel data.""" + # Mock response with invalid channel data + mock_response = Mock(spec=SlackResponse) + mock_response.data = {"ok": True, "channel": "invalid_data"} + self.mock_web_client.conversations_info.return_value = mock_response + + # Should raise ValueError + with pytest.raises(ValueError, match="Invalid channel data received"): + self.client.get_channel_info("C1234567890") + + def test_list_channels_success(self): + """Test successful channel list retrieval.""" + # Mock channel data + mock_channels = [ + { + "id": "C1234567890", + "name": "general", + "is_channel": True, + "is_member": True, + }, + { + "id": "C0987654321", + "name": "random", + "is_channel": True, + "is_member": True, + }, + { + "id": "C1111111111", + 
"name": "private-group", + "is_group": True, + }, + ] + + mock_response = Mock(spec=SlackResponse) + mock_response.data = { + "ok": True, + "channels": mock_channels, + "response_metadata": {"next_cursor": ""}, + } + self.mock_web_client.conversations_list.return_value = mock_response + + # Call method + result = self.client.list_channels() + + # Verify API call + self.mock_web_client.conversations_list.assert_called_once() + + # Verify result - should include all channels where user is member + assert len(result) == 3 # All channels should be included + + def test_list_channels_with_pagination(self): + """Test channel list with pagination.""" + # Mock first page + mock_channels_page1 = [ + { + "id": "C1234567890", + "name": "general", + "is_channel": True, + "is_member": True, + } + ] + mock_response_page1 = Mock(spec=SlackResponse) + mock_response_page1.data = { + "ok": True, + "channels": mock_channels_page1, + "response_metadata": {"next_cursor": "cursor123"}, + } + + # Mock second page + mock_channels_page2 = [ + { + "id": "C0987654321", + "name": "random", + "is_channel": True, + "is_member": True, + } + ] + mock_response_page2 = Mock(spec=SlackResponse) + mock_response_page2.data = { + "ok": True, + "channels": mock_channels_page2, + "response_metadata": {"next_cursor": ""}, + } + + # Set up side_effect for multiple calls + self.mock_web_client.conversations_list.side_effect = [ + mock_response_page1, + mock_response_page2, + ] + + # Call method + result = self.client.list_channels() + + # Verify multiple API calls + assert self.mock_web_client.conversations_list.call_count == 2 + + # Verify result combines both pages + assert len(result) == 2 + + def test_list_channels_member_filtering(self): + """Test that list_channels filters to only channels where user is member.""" + # Mock channels with mixed membership + mock_channels = [ + { + "id": "C1234567890", + "name": "general", + "is_channel": True, + "is_member": True, # User is member + }, + { + "id": 
"C0987654321", + "name": "other-team", + "is_channel": True, + "is_member": False, # User is NOT member + }, + { + "id": "G1111111111", + "name": "private-group", + "is_group": True, + # Private groups don't have is_member field + }, + ] + + mock_response = Mock(spec=SlackResponse) + mock_response.data = { + "ok": True, + "channels": mock_channels, + "response_metadata": {"next_cursor": ""}, + } + self.mock_web_client.conversations_list.return_value = mock_response + + # Call method + result = self.client.list_channels() + + # Should only include channels where user is member + private groups + assert len(result) == 2 + channel_ids = [ch["id"] for ch in result] + assert "C1234567890" in channel_ids # Member of public channel + assert "C0987654321" not in channel_ids # Not member of public channel + assert "G1111111111" in channel_ids # Private group (always included) + + def test_get_channel_members_success(self): + """Test successful channel members retrieval.""" + # Mock members data + mock_members = ["U1234567890", "U0987654321", "U1111111111"] + + mock_response = Mock(spec=SlackResponse) + mock_response.data = { + "ok": True, + "members": mock_members, + "response_metadata": {"next_cursor": ""}, + } + self.mock_web_client.conversations_members.return_value = mock_response + + # Call method + result = self.client.get_channel_members("C1234567890") + + # Verify API call + self.mock_web_client.conversations_members.assert_called_once_with( + channel="C1234567890", cursor=None + ) + + # Verify result + assert result == mock_members + + def test_get_conversation_history_success(self): + """Test successful conversation history retrieval.""" + # Mock message data + mock_messages = [ + { + "ts": "1234567890.001", + "user": "U1234567890", + "text": "Hello everyone!", + }, + { + "ts": "1234567890.002", + "user": "U0987654321", + "text": "Hi there!", + }, + ] + + mock_response = Mock(spec=SlackResponse) + mock_response.data = { + "ok": True, + "messages": mock_messages, + 
"response_metadata": {"next_cursor": "cursor123"}, + } + self.mock_web_client.conversations_history.return_value = mock_response + + # Call method + messages, next_cursor = self.client.get_conversation_history( + channel_id="C1234567890", + oldest="1234567880.000", + latest="1234567900.000", + inclusive=True, + limit=50, + ) + + # Verify API call + self.mock_web_client.conversations_history.assert_called_once_with( + channel="C1234567890", + limit=50, + oldest="1234567880.000", + latest="1234567900.000", + cursor=None, + inclusive=True, + ) + + # Verify result + assert messages == mock_messages + assert next_cursor == "cursor123" + + def test_get_conversation_history_error(self): + """Test conversation history retrieval with error.""" + # Mock error response + mock_response = Mock(spec=SlackResponse) + mock_response.data = {"ok": False, "error": "channel_not_found"} + self.mock_web_client.conversations_history.return_value = mock_response + + # Should raise SlackApiError + with pytest.raises(SlackApiError): + self.client.get_conversation_history("C1234567890") + + def test_get_thread_replies_success(self): + """Test successful thread replies retrieval.""" + # Mock thread messages + mock_messages = [ + { + "ts": "1234567890.001", + "user": "U1234567890", + "text": "Parent message", + }, + { + "ts": "1234567890.002", + "user": "U0987654321", + "text": "Reply message", + }, + ] + + mock_response = Mock(spec=SlackResponse) + mock_response.data = {"ok": True, "messages": mock_messages} + self.mock_web_client.conversations_replies.return_value = mock_response + + # Call method + result = self.client.get_thread_replies("C1234567890", "1234567890.001") + + # Verify API call + self.mock_web_client.conversations_replies.assert_called_once_with( + channel="C1234567890", ts="1234567890.001" + ) + + # Verify result + assert result == mock_messages + + def test_search_messages_success(self): + """Test successful message search.""" + # Mock search results + mock_messages = [ + { 
+ "ts": "1234567890.001", + "user": "U1234567890", + "text": "Found message 1", + "channel": {"id": "C1234567890"}, + }, + { + "ts": "1234567890.002", + "user": "U0987654321", + "text": "Found message 2", + "channel": {"id": "C0987654321"}, + }, + ] + + mock_pagination = { + "page": 1, + "page_count": 1, + "total": 2, + } + + mock_response = Mock(spec=SlackResponse) + mock_response.data = { + "ok": True, + "messages": { + "matches": mock_messages, + "pagination": mock_pagination, + }, + } + self.mock_web_client.search_messages.return_value = mock_response + + # Call method + messages, pagination = self.client.search_messages( + query="test query", sort="timestamp", sort_dir="desc", count=20, page=1 + ) + + # Verify API call + self.mock_web_client.search_messages.assert_called_once_with( + query="test query", sort="timestamp", sort_dir="desc", count=20, page=1 + ) + + # Verify result + assert messages == mock_messages + assert pagination == mock_pagination + + def test_search_messages_in_channel_success(self): + """Test successful message search within channel.""" + # Mock search results + mock_messages = [ + { + "ts": "1234567890.001", + "user": "U1234567890", + "text": "Channel message 1", + "channel": {"id": "C1234567890"}, + }, + ] + + mock_response = Mock(spec=SlackResponse) + mock_response.data = { + "ok": True, + "messages": { + "matches": mock_messages, + "pagination": {}, + }, + } + self.mock_web_client.search_messages.return_value = mock_response + + # Call method + result = self.client.search_messages_in_channel( + channel_id="C1234567890", + query="test content", + oldest="1234567880.000", + latest="1234567900.000", + limit=50, + ) + + # Verify API call was made with channel filter + self.mock_web_client.search_messages.assert_called_once() + call_args = self.mock_web_client.search_messages.call_args + query_arg = call_args[1]["query"] + + # Should include channel filter and content + assert "test content" in query_arg + assert "in:<#C1234567890>" in 
query_arg + assert "after:" in query_arg # Date filters + assert "before:" in query_arg + + # Verify result + assert result == mock_messages + + def test_get_user_info_success(self): + """Test successful user info retrieval.""" + # Mock user data + mock_user_data = { + "id": "U1234567890", + "name": "alice", + "real_name": "Alice Smith", + "is_bot": False, + "is_admin": True, + "profile": { + "display_name": "Alice", + "email": "alice@example.com", + "title": "Software Engineer", + }, + } + + mock_response = Mock(spec=SlackResponse) + mock_response.data = {"ok": True, "user": mock_user_data} + self.mock_web_client.users_info.return_value = mock_response + + # Call method + result = self.client.get_user_info("U1234567890") + + # Verify API call + self.mock_web_client.users_info.assert_called_once_with(user="U1234567890") + + # Verify result + assert result == mock_user_data + + def test_get_user_info_error(self): + """Test user info retrieval with error.""" + # Mock error response + mock_response = Mock(spec=SlackResponse) + mock_response.data = {"ok": False, "error": "user_not_found"} + self.mock_web_client.users_info.return_value = mock_response + + # Should raise SlackApiError + with pytest.raises(SlackApiError): + self.client.get_user_info("U1234567890") + + def test_list_users_success(self): + """Test successful user list retrieval.""" + # Mock user data + mock_users = [ + { + "id": "U1234567890", + "name": "alice", + "real_name": "Alice Smith", + }, + { + "id": "U0987654321", + "name": "bob", + "real_name": "Bob Jones", + }, + ] + + mock_response = Mock(spec=SlackResponse) + mock_response.data = { + "ok": True, + "members": mock_users, + "response_metadata": {"next_cursor": ""}, + } + self.mock_web_client.users_list.return_value = mock_response + + # Call method + result = self.client.list_users() + + # Verify API call + self.mock_web_client.users_list.assert_called_once() + + # Verify result + assert result == mock_users + + def 
test_lookup_user_by_email_success(self): + """Test successful user lookup by email.""" + # Mock user data + mock_user_data = { + "id": "U1234567890", + "name": "alice", + "profile": {"email": "alice@example.com"}, + } + + mock_response = Mock(spec=SlackResponse) + mock_response.data = {"ok": True, "user": mock_user_data} + self.mock_web_client.users_lookupByEmail.return_value = mock_response + + # Call method + result = self.client.lookup_user_by_email("alice@example.com") + + # Verify API call + self.mock_web_client.users_lookupByEmail.assert_called_once_with( + email="alice@example.com" + ) + + # Verify result + assert result == mock_user_data + + def test_lookup_user_by_email_not_found(self): + """Test user lookup by email when user not found.""" + # Mock not found response + mock_response = Mock(spec=SlackResponse) + mock_response.data = {"ok": False, "error": "users_not_found"} + self.mock_web_client.users_lookupByEmail.return_value = mock_response + + # Should return None for not found + result = self.client.lookup_user_by_email("nonexistent@example.com") + assert result is None + + def test_lookup_user_by_email_other_error(self): + """Test user lookup by email with other error.""" + # Mock other error response + mock_response = Mock(spec=SlackResponse) + mock_response.data = {"ok": False, "error": "invalid_email"} + self.mock_web_client.users_lookupByEmail.return_value = mock_response + + # Should raise SlackApiError for other errors + with pytest.raises(SlackApiError): + self.client.lookup_user_by_email("invalid-email") + + def test_test_auth_success(self): + """Test successful auth test.""" + # Mock auth response + mock_auth_data = { + "ok": True, + "url": "https://example.slack.com/", + "team": "Example Team", + "user": "alice", + "team_id": "T1234567890", + "user_id": "U1234567890", + } + + mock_response = Mock(spec=SlackResponse) + mock_response.data = mock_auth_data + self.mock_web_client.auth_test.return_value = mock_response + + # Call method + 
result = self.client.test_auth() + + # Verify API call + self.mock_web_client.auth_test.assert_called_once() + + # Verify result + assert result == mock_auth_data + + def test_test_auth_error(self): + """Test auth test with error.""" + # Mock error response + mock_response = Mock(spec=SlackResponse) + mock_response.data = {"ok": False, "error": "invalid_auth"} + self.mock_web_client.auth_test.return_value = mock_response + + # Should raise SlackApiError + with pytest.raises(SlackApiError): + self.client.test_auth() + + def test_get_team_info_success(self): + """Test successful team info retrieval.""" + # Mock team data + mock_team_data = { + "id": "T1234567890", + "name": "Example Team", + "domain": "example", + "email_domain": "example.com", + } + + mock_response = Mock(spec=SlackResponse) + mock_response.data = {"ok": True, "team": mock_team_data} + self.mock_web_client.team_info.return_value = mock_response + + # Call method + result = self.client.get_team_info() + + # Verify API call + self.mock_web_client.team_info.assert_called_once() + + # Verify result + assert result == mock_team_data + + def test_get_team_info_error(self): + """Test team info retrieval with error.""" + # Mock error response + mock_response = Mock(spec=SlackResponse) + mock_response.data = {"ok": False, "error": "access_denied"} + self.mock_web_client.team_info.return_value = mock_response + + # Should raise SlackApiError + with pytest.raises(SlackApiError): + self.client.get_team_info() + + def test_get_team_info_invalid_data(self): + """Test team info retrieval with invalid team data.""" + # Mock response with invalid team data + mock_response = Mock(spec=SlackResponse) + mock_response.data = {"ok": True, "team": "invalid_data"} + self.mock_web_client.team_info.return_value = mock_response + + # Should raise ValueError + with pytest.raises(ValueError, match="Invalid team data received"): + self.client.get_team_info() + + +class TestSlackAPIClientEdgeCases: + """Test edge cases and error 
handling for SlackAPIClient.""" + + def setup_method(self): + """Set up test environment.""" + self.mock_auth_manager = Mock() + self.mock_web_client = Mock() + self.mock_auth_manager.get_client.return_value = self.mock_web_client + + self.client = SlackAPIClient(self.mock_auth_manager) + + def test_list_channels_with_limit(self): + """Test list_channels respects limit parameter.""" + # Mock response with many channels + mock_channels = [ + { + "id": f"C{i:010d}", + "name": f"channel{i}", + "is_channel": True, + "is_member": True, + } + for i in range(50) + ] + + mock_response = Mock(spec=SlackResponse) + mock_response.data = { + "ok": True, + "channels": mock_channels, + "response_metadata": {"next_cursor": ""}, + } + self.mock_web_client.conversations_list.return_value = mock_response + + # Call with limit + result = self.client.list_channels(limit=10) + + # Should only return 10 channels + assert len(result) == 10 + + def test_get_conversation_history_with_all_params(self): + """Test conversation history with all parameters.""" + # Mock response + mock_response = Mock(spec=SlackResponse) + mock_response.data = { + "ok": True, + "messages": [], + "response_metadata": {"next_cursor": ""}, + } + self.mock_web_client.conversations_history.return_value = mock_response + + # Call with all parameters + self.client.get_conversation_history( + channel_id="C1234567890", + oldest="1234567880.000", + latest="1234567900.000", + inclusive=True, + limit=50, + cursor="cursor123", + ) + + # Verify all parameters passed correctly + self.mock_web_client.conversations_history.assert_called_once_with( + channel="C1234567890", + limit=50, + oldest="1234567880.000", + latest="1234567900.000", + cursor="cursor123", + inclusive=True, + ) + + def test_search_messages_invalid_response(self): + """Test search_messages with invalid response structure.""" + # Mock response with invalid messages section + mock_response = Mock(spec=SlackResponse) + mock_response.data = { + "ok": True, + 
"messages": "invalid_structure", # Should be dict + } + self.mock_web_client.search_messages.return_value = mock_response + + # Should handle gracefully and return empty results + messages, pagination = self.client.search_messages("test") + assert messages == [] + assert pagination == {} + + def test_large_member_list_pagination(self): + """Test handling of large member lists with pagination.""" + # Mock first page + mock_members_page1 = [f"U{i:010d}" for i in range(1000)] + mock_response_page1 = Mock(spec=SlackResponse) + mock_response_page1.data = { + "ok": True, + "members": mock_members_page1, + "response_metadata": {"next_cursor": "cursor123"}, + } + + # Mock second page + mock_members_page2 = [f"U{i:010d}" for i in range(1000, 1500)] + mock_response_page2 = Mock(spec=SlackResponse) + mock_response_page2.data = { + "ok": True, + "members": mock_members_page2, + "response_metadata": {"next_cursor": ""}, + } + + self.mock_web_client.conversations_members.side_effect = [ + mock_response_page1, + mock_response_page2, + ] + + # Call method + result = self.client.get_channel_members("C1234567890") + + # Should combine all pages + assert len(result) == 1500 + assert result[0] == "U0000000000" + assert result[-1] == "U0000001499" + + def test_api_rate_limit_handling(self): + """Test handling of API rate limit errors.""" + # Mock rate limit error + mock_response = Mock(spec=SlackResponse) + mock_response.data = {"ok": False, "error": "rate_limited"} + self.mock_web_client.conversations_info.return_value = mock_response + + # Should raise SlackApiError with rate limit info + with pytest.raises(SlackApiError): + self.client.get_channel_info("C1234567890") diff --git a/tests/services/test_slack_service.py b/tests/services/test_slack_service.py new file mode 100644 index 0000000..fe848ad --- /dev/null +++ b/tests/services/test_slack_service.py @@ -0,0 +1,1199 @@ +"""Tests for SlackService and SlackToolkit.""" + +from datetime import datetime, timezone +from unittest.mock 
import Mock + +import pytest +from pydantic import ValidationError + +from praga_core import ServerContext, clear_global_context, set_global_context +from praga_core.types import PageURI +from pragweb.slack import ( + SlackChannelListPage, + SlackChannelPage, + SlackConversationPage, + SlackMessagePage, + SlackMessageSummary, + SlackService, + SlackThreadPage, + SlackToolkit, + SlackUserPage, +) + + +class TestSlackService: + """Test suite for SlackService.""" + + def setup_method(self): + """Set up test environment.""" + # Clear any existing global context first + clear_global_context() + + # Create real ServerContext with in-memory SQLite PageCache + self.context = ServerContext(root="test-root", cache_url="sqlite:///:memory:") + + # Mock handler registration to avoid complexity + self.context.handler = Mock() + + # Make handler decorator work as a no-op + def mock_handler_decorator(page_type): + def decorator(func): + return func + + return decorator + + self.context.handler.side_effect = mock_handler_decorator + + set_global_context(self.context) + + # Create mock SlackAPIClient + self.mock_api_client = Mock() + self.mock_api_client.get_conversation_history = Mock() + self.mock_api_client.get_thread_replies = Mock() + self.mock_api_client.get_channel_info = Mock() + self.mock_api_client.get_channel_members = Mock() + self.mock_api_client.get_user_info = Mock() + self.mock_api_client.list_channels = Mock() + self.mock_api_client.get_team_info = Mock() + self.mock_api_client.search_messages = Mock() + + self.service = SlackService(self.mock_api_client) + + def teardown_method(self): + """Clean up test environment.""" + clear_global_context() + + def test_init(self): + """Test SlackService initialization.""" + assert self.service.api_client is self.mock_api_client + assert self.service.parser is not None + assert self.service.name == "slack" + assert "slack" in self.context.services + assert self.context.services["slack"] is self.service + + def 
test_create_conversation_page_success(self): + """Test successful conversation page creation.""" + # Mock conversation history response + mock_messages = [ + { + "ts": "1234567890.001", + "user": "U123456", + "text": "Hello everyone!", + }, + { + "ts": "1234567890.002", + "user": "U789012", + "text": "Hi there!", + }, + ] + + self.mock_api_client.get_conversation_history.return_value = ( + mock_messages, + None, + ) + + # Mock channel info + mock_channel_info = { + "id": "C1234567890", + "name": "general", + "is_channel": True, + "is_group": False, + "is_im": False, + "is_mpim": False, + } + self.mock_api_client.get_channel_info.return_value = mock_channel_info + self.mock_api_client.get_channel_members.return_value = ["U123456", "U789012"] + + # Mock user info + def mock_user_info(user_id): + users = { + "U123456": { + "id": "U123456", + "name": "alice", + "real_name": "Alice Smith", + "profile": {"display_name": "Alice"}, + }, + "U789012": { + "id": "U789012", + "name": "bob", + "real_name": "Bob Jones", + "profile": {"display_name": "Bob"}, + }, + } + return users.get(user_id, {}) + + self.mock_api_client.get_user_info.side_effect = mock_user_info + + # Real page cache is now being used - no mocking needed + + # Call create_conversation_page + result = self.service.create_conversation_page("C1234567890") + + # Verify API calls + self.mock_api_client.get_conversation_history.assert_called_once() + self.mock_api_client.get_channel_info.assert_called_once_with("C1234567890") + + # Verify result + assert isinstance(result, SlackConversationPage) + assert result.conversation_id == "C1234567890" + assert result.channel_id == "C1234567890" + assert result.channel_name == "general" + assert result.channel_type == "public_channel" + assert result.message_count == 2 + assert "Alice" in result.participants + assert "Bob" in result.participants + assert "Hello everyone!" in result.messages_content + assert "Hi there!" 
in result.messages_content + + # Verify URI + expected_uri = PageURI( + root="test-root", type="slack_conversation", id="C1234567890", version=1 + ) + assert result.uri == expected_uri + + def test_create_conversation_page_no_messages(self): + """Test create_conversation_page with no messages.""" + self.mock_api_client.get_conversation_history.return_value = ([], None) + + # Mock channel info (needed for get_channel_page call) + mock_channel_info = { + "id": "C1234567890", + "name": "general", + "is_channel": True, + "is_group": False, + "is_im": False, + "is_mpim": False, + "created": 1234567890, + } + self.mock_api_client.get_channel_info.return_value = mock_channel_info + self.mock_api_client.get_channel_members.return_value = [] + + with pytest.raises(ValueError, match="No messages found in channel"): + self.service.create_conversation_page("C1234567890") + + def test_create_message_page_success(self): + """Test successful message page creation.""" + # Mock message data + mock_messages = [ + { + "ts": "1234567890.001", + "user": "U123456", + "text": "This is a test message", + "thread_ts": None, + } + ] + self.mock_api_client.get_conversation_history.return_value = ( + mock_messages, + None, + ) + + # Mock channel info + mock_channel_info = { + "id": "C1234567890", + "name": "general", + "is_channel": True, + "is_group": False, + "is_im": False, + "is_mpim": False, + } + self.mock_api_client.get_channel_info.return_value = mock_channel_info + self.mock_api_client.get_channel_members.return_value = ["U123456"] + + # Mock user info + self.mock_api_client.get_user_info.return_value = { + "id": "U123456", + "name": "alice", + "real_name": "Alice Smith", + "profile": {"display_name": "Alice"}, + } + + # Real page cache is now being used - no mocking needed + + # Create message page + message_id = "C1234567890_1234567890.001" + result = self.service.create_message_page(message_id) + + # Verify result + assert isinstance(result, SlackMessagePage) + assert 
result.message_ts == "1234567890.001" + assert result.channel_id == "C1234567890" + assert result.channel_name == "general" + assert result.display_name == "Alice" + assert result.text_content == "This is a test message" + assert result.thread_ts is None + + # Verify URI + expected_uri = PageURI( + root="test-root", type="slack_message", id=message_id, version=1 + ) + assert result.uri == expected_uri + + def test_create_message_page_not_found(self): + """Test create_message_page when message not found.""" + self.mock_api_client.get_conversation_history.return_value = ([], None) + + # Mock channel info (needed for get_channel_page call) + mock_channel_info = { + "id": "C1234567890", + "name": "general", + "is_channel": True, + "is_group": False, + "is_im": False, + "is_mpim": False, + "created": 1234567890, + } + self.mock_api_client.get_channel_info.return_value = mock_channel_info + self.mock_api_client.get_channel_members.return_value = [] + + with pytest.raises(RuntimeError, match="Unable to find message"): + self.service.create_message_page("C1234567890_1234567890.001") + + def test_create_thread_page_success(self): + """Test successful thread page creation.""" + # Mock thread messages + mock_messages = [ + { + "ts": "1234567890.001", + "user": "U123456", + "text": "This is the parent message", + }, + { + "ts": "1234567890.002", + "user": "U789012", + "text": "This is a reply", + }, + ] + self.mock_api_client.get_thread_replies.return_value = mock_messages + + # Mock channel info + mock_channel_info = { + "id": "C1234567890", + "name": "general", + "is_channel": True, + "is_group": False, + "is_im": False, + "is_mpim": False, + } + self.mock_api_client.get_channel_info.return_value = mock_channel_info + self.mock_api_client.get_channel_members.return_value = ["U123456", "U789012"] + + # Mock user info + def mock_user_info(user_id): + users = { + "U123456": { + "id": "U123456", + "name": "alice", + "profile": {"display_name": "Alice"}, + }, + "U789012": { + 
"id": "U789012", + "name": "bob", + "profile": {"display_name": "Bob"}, + }, + } + return users.get(user_id, {}) + + self.mock_api_client.get_user_info.side_effect = mock_user_info + + # Real page cache is now being used - no mocking needed + + # Create thread page + thread_id = "C1234567890_1234567890.001" + result = self.service.create_thread_page(thread_id) + + # Verify result + assert isinstance(result, SlackThreadPage) + assert result.thread_ts == "1234567890.001" + assert result.channel_id == "C1234567890" + assert result.channel_name == "general" + assert result.parent_message == "This is the parent message" + assert result.message_count == 2 + assert len(result.messages) == 2 + assert "Alice" in result.participants + assert "Bob" in result.participants + + # Verify messages + assert result.messages[0].display_name == "Alice" + assert result.messages[0].text == "This is the parent message" + assert result.messages[1].display_name == "Bob" + assert result.messages[1].text == "This is a reply" + + def test_create_thread_page_no_messages(self): + """Test create_thread_page with no messages.""" + self.mock_api_client.get_thread_replies.return_value = [] + + with pytest.raises(ValueError, match="Thread .* contains no messages"): + self.service.create_thread_page("C1234567890_1234567890.001") + + def test_create_channel_page_success(self): + """Test successful channel page creation.""" + # Mock channel info + mock_channel_info = { + "id": "C1234567890", + "name": "general", + "is_channel": True, + "is_group": False, + "is_im": False, + "is_mpim": False, + "topic": {"value": "General discussion"}, + "purpose": {"value": "Company-wide announcements"}, + "created": 1234567890, + "is_archived": False, + } + self.mock_api_client.get_channel_info.return_value = mock_channel_info + self.mock_api_client.get_channel_members.return_value = ["U123456", "U789012"] + + # Create channel page + result = self.service.create_channel_page("C1234567890") + + # Verify result + assert 
isinstance(result, SlackChannelPage) + assert result.channel_id == "C1234567890" + assert result.name == "general" + assert result.channel_type == "public_channel" + assert result.topic == "General discussion" + assert result.purpose == "Company-wide announcements" + assert result.member_count == 2 + assert not result.is_archived + + # Verify URI + expected_uri = PageURI( + root="test-root", type="slack_channel", id="C1234567890", version=1 + ) + assert result.uri == expected_uri + + def test_create_user_page_success(self): + """Test successful user page creation.""" + # Mock user info + mock_user_info = { + "id": "U123456", + "name": "alice", + "real_name": "Alice Smith", + "is_bot": False, + "is_admin": True, + "profile": { + "display_name": "Alice", + "email": "alice@example.com", + "title": "Software Engineer", + "status_text": "Working on tests", + "status_emoji": ":computer:", + }, + } + self.mock_api_client.get_user_info.return_value = mock_user_info + + # Create user page + result = self.service.create_user_page("U123456") + + # Verify result + assert isinstance(result, SlackUserPage) + assert result.user_id == "U123456" + assert result.name == "alice" + assert result.real_name == "Alice Smith" + assert result.display_name == "Alice" + assert result.email == "alice@example.com" + assert result.title == "Software Engineer" + assert not result.is_bot + assert result.is_admin + assert result.status_text == "Working on tests" + assert result.status_emoji == ":computer:" + + # Verify URI + expected_uri = PageURI( + root="test-root", type="slack_user", id="U123456", version=1 + ) + assert result.uri == expected_uri + + def test_get_user_display_name(self): + """Test get_user_display_name method.""" + # Mock user info with display name + mock_user_info = { + "id": "U123456", + "name": "alice", + "real_name": "Alice Smith", + "profile": {"display_name": "Alice"}, + } + self.mock_api_client.get_user_info.return_value = mock_user_info + + # Real page cache is now 
being used - no mocking needed + + result = self.service.get_user_display_name("U123456") + assert result == "Alice" + + # Test with empty user ID + result = self.service.get_user_display_name("") + assert result == "unknown" + + def test_search_messages(self): + """Test search_messages method.""" + # Mock search response + mock_messages = [ + { + "ts": "1234567890.001", + "channel": {"id": "C1234567890"}, + "user": "U123456", + "text": "Test message 1", + }, + { + "ts": "1234567890.002", + "channel": {"id": "C1234567890"}, + "user": "U789012", + "text": "Test message 2", + }, + ] + mock_pagination = {"page": 1, "page_count": 1, "total": 2} + + self.mock_api_client.search_messages.return_value = ( + mock_messages, + mock_pagination, + ) + + # Call search_messages + uris, next_token = self.service.search_messages("test query") + + # Verify API call + self.mock_api_client.search_messages.assert_called_once() + + # Verify results + assert len(uris) == 2 + assert all(isinstance(uri, PageURI) for uri in uris) + assert all(uri.type == "slack_message" for uri in uris) + assert all(uri.root == "test-root" for uri in uris) + + def test_name_property(self): + """Test name property.""" + assert self.service.name == "slack" + + +class TestSlackPageTypes: + """Test page type serialization and validation.""" + + def setup_method(self): + """Set up test data.""" + self.test_uri = PageURI( + root="test", type="slack_message", id="test123", version=1 + ) + self.test_time = datetime(2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc) + + def test_slack_message_summary_serialization(self): + """Test SlackMessageSummary serialization.""" + summary = SlackMessageSummary( + display_name="Alice", + text="Test message", + timestamp=self.test_time, + ) + + # Test serialization + data = summary.model_dump() + assert data["display_name"] == "Alice" + assert data["text"] == "Test message" + assert isinstance(data["timestamp"], datetime) + + # Test deserialization + new_summary = 
SlackMessageSummary.model_validate(data) + assert new_summary.display_name == "Alice" + assert new_summary.text == "Test message" + assert new_summary.timestamp == self.test_time + + def test_slack_message_summary_required_fields(self): + """Test SlackMessageSummary with missing required fields.""" + with pytest.raises(ValidationError): + SlackMessageSummary() + + with pytest.raises(ValidationError): + SlackMessageSummary(display_name="Alice") + + def test_slack_conversation_page_serialization(self): + """Test SlackConversationPage serialization.""" + # Create proper URI for slack_conversation page type + conversation_uri = PageURI( + root="test", type="slack_conversation", id="conv123", version=1 + ) + + page = SlackConversationPage( + uri=conversation_uri, + conversation_id="conv123", + channel_id="C1234567890", + channel_name="general", + channel_type="public_channel", + start_time=self.test_time, + end_time=self.test_time, + message_count=5, + participants=["Alice", "Bob"], + messages_content="Alice: Hello\nBob: Hi", + permalink="https://slack.com/app_redirect?channel=C1234567890", + ) + + # Test serialization + data = page.model_dump() + assert data["channel_name"] == "general" + assert data["channel_type"] == "public_channel" + assert data["message_count"] == 5 + assert data["participants"] == ["Alice", "Bob"] + assert "conversation_id" not in data # Should be excluded + + # Test deserialization + data["uri"] = conversation_uri # Add URI back for validation + data["conversation_id"] = "conv123" # Add excluded fields back for validation + data["channel_id"] = "C1234567890" + new_page = SlackConversationPage.model_validate(data) + assert new_page.channel_name == "general" + assert new_page.message_count == 5 + + def test_slack_thread_page_serialization(self): + """Test SlackThreadPage serialization.""" + message1 = SlackMessageSummary( + display_name="Alice", + text="Parent message", + timestamp=self.test_time, + ) + message2 = SlackMessageSummary( + 
display_name="Bob", + text="Reply message", + timestamp=self.test_time, + ) + + # Create proper URI for slack_thread page type + thread_uri = PageURI( + root="test", type="slack_thread", id="thread123", version=1 + ) + + page = SlackThreadPage( + uri=thread_uri, + thread_ts="1234567890.001", + channel_id="C1234567890", + channel_name="general", + parent_message="Parent message", + messages=[message1, message2], + message_count=2, + participants=["Alice", "Bob"], + created_at=self.test_time, + last_reply_at=self.test_time, + permalink="https://slack.com/app_redirect?channel=C1234567890&message_ts=1234567890.001", + ) + + # Test serialization + data = page.model_dump() + assert data["channel_name"] == "general" + assert data["parent_message"] == "Parent message" + assert data["message_count"] == 2 + assert len(data["messages"]) == 2 + assert "thread_ts" not in data # Should be excluded + + # Test deserialization + data["uri"] = thread_uri # Add URI back for validation + data["thread_ts"] = "1234567890.001" # Add excluded fields back for validation + data["channel_id"] = "C1234567890" + new_page = SlackThreadPage.model_validate(data) + assert new_page.channel_name == "general" + assert new_page.message_count == 2 + assert len(new_page.messages) == 2 + + def test_slack_thread_page_methods(self): + """Test SlackThreadPage helper methods.""" + message1 = SlackMessageSummary( + display_name="Alice", + text="Hello everyone", + timestamp=self.test_time, + ) + message2 = SlackMessageSummary( + display_name="Bob", + text="Hi there!", + timestamp=self.test_time, + ) + + page = SlackThreadPage( + uri=self.test_uri, + thread_ts="1234567890.001", + channel_id="C1234567890", + channel_name="general", + parent_message="Hello everyone, let's discuss this", + messages=[message1, message2], + message_count=2, + participants=["Alice", "Bob"], + created_at=self.test_time, + last_reply_at=self.test_time, + 
permalink="https://slack.com/app_redirect?channel=C1234567890&message_ts=1234567890.001", + ) + + # Test thread_messages property + thread_messages = page.thread_messages + assert "Alice: Hello everyone" in thread_messages + assert "Bob: Hi there!" in thread_messages + + def test_slack_thread_page_long_parent_message(self): + """Test SlackThreadPage with long parent message.""" + thread_uri = PageURI( + root="test", type="slack_thread", id="thread123", version=1 + ) + + long_message = "This is a very long message that should be stored properly in the thread page" + page = SlackThreadPage( + uri=thread_uri, + thread_ts="1234567890.001", + channel_id="C1234567890", + channel_name="general", + parent_message=long_message, + messages=[], + message_count=0, + participants=[], + created_at=self.test_time, + last_reply_at=None, + permalink="https://slack.com/app_redirect?channel=C1234567890&message_ts=1234567890.001", + ) + + assert page.parent_message == long_message + assert page.message_count == 0 + assert len(page.messages) == 0 + + def test_slack_channel_page_serialization(self): + """Test SlackChannelPage serialization.""" + channel_uri = PageURI( + root="test", type="slack_channel", id="C1234567890", version=1 + ) + + page = SlackChannelPage( + uri=channel_uri, + channel_id="C1234567890", + name="general", + channel_type="public_channel", + topic="General discussion", + purpose="Company-wide announcements", + member_count=100, + created=self.test_time, + is_archived=False, + last_activity=self.test_time, + permalink="https://slack.com/app_redirect?channel=C1234567890", + ) + + # Test serialization + data = page.model_dump() + assert data["name"] == "general" + assert data["channel_type"] == "public_channel" + assert data["member_count"] == 100 + assert not data["is_archived"] + assert "channel_id" not in data # Should be excluded + + # Test deserialization + data["uri"] = channel_uri # Add URI back for validation + data["channel_id"] = "C1234567890" # Add excluded 
fields back for validation + new_page = SlackChannelPage.model_validate(data) + assert new_page.name == "general" + assert new_page.member_count == 100 + + def test_slack_user_page_serialization(self): + """Test SlackUserPage serialization.""" + user_uri = PageURI(root="test", type="slack_user", id="U123456", version=1) + + page = SlackUserPage( + uri=user_uri, + user_id="U123456", + name="alice", + real_name="Alice Smith", + display_name="Alice", + email="alice@example.com", + title="Software Engineer", + is_bot=False, + is_admin=True, + status_text="Working", + status_emoji=":computer:", + last_updated=self.test_time, + ) + + # Test serialization + data = page.model_dump() + assert data["name"] == "alice" + assert data["real_name"] == "Alice Smith" + assert data["email"] == "alice@example.com" + assert not data["is_bot"] + assert data["is_admin"] + assert "user_id" not in data # Should be excluded + + # Test deserialization + data["uri"] = user_uri # Add URI back for validation + data["user_id"] = "U123456" # Add excluded fields back for validation + new_page = SlackUserPage.model_validate(data) + assert new_page.name == "alice" + assert new_page.is_admin + + def test_slack_message_page_serialization(self): + """Test SlackMessagePage serialization.""" + message_uri = PageURI( + root="test", + type="slack_message", + id="C1234567890_1234567890.001", + version=1, + ) + + page = SlackMessagePage( + uri=message_uri, + message_ts="1234567890.001", + channel_id="C1234567890", + channel_name="general", + channel_type="public_channel", + user_id="U123456", + display_name="Alice", + text_content="Test message", + timestamp=self.test_time, + thread_ts="1234567890.001", + permalink="https://slack.com/app_redirect?channel=C1234567890&message_ts=1234567890.001", + ) + + # Test serialization + data = page.model_dump() + assert data["channel_name"] == "general" + assert data["display_name"] == "Alice" + assert data["text_content"] == "Test message" + assert "message_ts" not in 
data # Should be excluded + assert "channel_id" not in data # Should be excluded + assert "user_id" not in data # Should be excluded + + # Test deserialization + data["uri"] = message_uri # Add URI back for validation + data["message_ts"] = "1234567890.001" # Add excluded fields back for validation + data["channel_id"] = "C1234567890" + data["user_id"] = "U123456" + new_page = SlackMessagePage.model_validate(data) + assert new_page.channel_name == "general" + assert new_page.display_name == "Alice" + + def test_slack_message_page_computed_fields(self): + """Test SlackMessagePage computed field properties.""" + page = SlackMessagePage( + uri=PageURI( + root="test", + type="slack_message", + id="C1234567890_1234567890.001", + version=1, + ), + message_ts="1234567890.001", + channel_id="C1234567890", + channel_name="general", + channel_type="public_channel", + user_id="U123456", + display_name="Alice", + text_content="Test message", + timestamp=self.test_time, + thread_ts="1234567890.001", + permalink="https://slack.com/app_redirect?channel=C1234567890&message_ts=1234567890.001", + ) + + # Test thread_uri property + thread_uri = page.thread_uri + assert isinstance(thread_uri, PageURI) + assert thread_uri.root == "test" + assert thread_uri.type == "slack_thread" + assert thread_uri.id == "C1234567890_1234567890.001" + assert thread_uri.version == 1 + + # Test with no thread_ts + page_no_thread = SlackMessagePage( + uri=self.test_uri, + message_ts="1234567890.001", + channel_id="C1234567890", + channel_name="general", + channel_type="public_channel", + user_id="U123456", + display_name="Alice", + text_content="Test message", + timestamp=self.test_time, + thread_ts=None, + permalink="https://slack.com/app_redirect?channel=C1234567890&message_ts=1234567890.001", + ) + + assert page_no_thread.thread_uri is None + + def test_slack_channel_list_page_serialization(self): + """Test SlackChannelListPage serialization.""" + channels_data = [ + { + "id": "C1234567890", + "name": 
"general", + "is_channel": True, + "member_count": 100, + }, + { + "id": "C0987654321", + "name": "random", + "is_channel": True, + "member_count": 50, + }, + ] + + channel_list_uri = PageURI( + root="test", type="slack_channel_list", id="T1234567890", version=1 + ) + + page = SlackChannelListPage( + uri=channel_list_uri, + workspace_id="T1234567890", + workspace_name="Test Workspace", + total_channels=2, + public_channels=2, + private_channels=0, + channels=channels_data, + last_updated=self.test_time, + ) + + # Test serialization + data = page.model_dump() + assert data["workspace_name"] == "Test Workspace" + assert data["total_channels"] == 2 + assert data["public_channels"] == 2 + assert len(data["channels"]) == 2 + assert "workspace_id" not in data # Should be excluded + + # Test deserialization + data["uri"] = channel_list_uri # Add URI back for validation + data["workspace_id"] = "T1234567890" # Add excluded fields back for validation + new_page = SlackChannelListPage.model_validate(data) + assert new_page.workspace_name == "Test Workspace" + assert new_page.total_channels == 2 + + +class TestSlackToolkit: + """Test SlackToolkit functionality.""" + + def setup_method(self): + """Set up test environment.""" + clear_global_context() + + # Create real ServerContext with in-memory SQLite PageCache + self.context = ServerContext(root="test-root", cache_url="sqlite:///:memory:") + + # Mock handler registration to avoid complexity + self.context.handler = Mock() + + def mock_handler_decorator(page_type): + def decorator(func): + return func + + return decorator + + self.context.handler.side_effect = mock_handler_decorator + + set_global_context(self.context) + + # Create mock SlackAPIClient and service + self.mock_api_client = Mock() + self.slack_service = SlackService(self.mock_api_client) + self.toolkit = SlackToolkit(self.slack_service) + + def _create_mock_message_page(self, msg_id="msg1"): + """Helper to create a mock SlackMessagePage.""" + mock_uri = 
PageURI(root="test-root", type="slack_message", id=msg_id, version=1) + return SlackMessagePage( + uri=mock_uri, + message_ts="1234567890.001", + channel_id="C1234567890", + channel_name="general", + channel_type="public_channel", + user_id="U123456", + display_name="Alice", + text_content="Test message", + timestamp=datetime(2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc), + thread_ts=None, # Add missing required field + permalink="https://slack.com/app_redirect?channel=C1234567890&message_ts=1234567890.001", + ) + + def teardown_method(self): + """Clean up test environment.""" + clear_global_context() + + def test_toolkit_init(self): + """Test SlackToolkit initialization.""" + assert self.toolkit.slack_service is self.slack_service + assert self.toolkit.name == "slack" + + def test_search_messages_by_content(self): + """Test search_messages_by_content tool.""" + # Mock search results with actual message pages + mock_uri1 = PageURI( + root="test-root", type="slack_message", id="msg1", version=1 + ) + mock_uri2 = PageURI( + root="test-root", type="slack_message", id="msg2", version=1 + ) + + mock_page1 = SlackMessagePage( + uri=mock_uri1, + message_ts="1234567890.001", + channel_id="C1234567890", + channel_name="general", + channel_type="public_channel", + user_id="U123456", + display_name="Alice", + text_content="Test message 1", + timestamp=datetime(2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc), + thread_ts=None, # Add missing required field + permalink="https://slack.com/app_redirect?channel=C1234567890&message_ts=1234567890.001", + ) + + mock_page2 = SlackMessagePage( + uri=mock_uri2, + message_ts="1234567890.002", + channel_id="C1234567890", + channel_name="general", + channel_type="public_channel", + user_id="U789012", + display_name="Bob", + text_content="Test message 2", + timestamp=datetime(2023, 6, 15, 10, 31, 0, tzinfo=timezone.utc), + thread_ts=None, # Add missing required field + 
permalink="https://slack.com/app_redirect?channel=C1234567890&message_ts=1234567890.002", + ) + + self.slack_service.search_messages = Mock( + return_value=([mock_uri1, mock_uri2], "next_token") + ) + + # Mock the page resolution - use the global context + def mock_get_page(uri): + if uri.id == "msg1": + return mock_page1 + elif uri.id == "msg2": + return mock_page2 + return Mock() + + self.context.get_page = Mock(side_effect=mock_get_page) + + # Call tool + result = self.toolkit.search_messages_by_content("test query") + + # Verify result + assert len(result.results) == 2 + assert result.next_cursor == "next_token" + assert all(isinstance(page, SlackMessagePage) for page in result.results) + + def test_search_messages_by_channel(self): + """Test search_messages_by_channel tool.""" + # Create mock message page + mock_uri = PageURI(root="test-root", type="slack_message", id="msg1", version=1) + mock_page = SlackMessagePage( + uri=mock_uri, + message_ts="1234567890.001", + channel_id="C1234567890", + channel_name="general", + channel_type="public_channel", + user_id="U123456", + display_name="Alice", + text_content="Test message", + timestamp=datetime(2023, 6, 15, 10, 30, 0, tzinfo=timezone.utc), + thread_ts=None, + permalink="https://slack.com/app_redirect?channel=C1234567890&message_ts=1234567890.001", + ) + + self.slack_service.search_messages = Mock(return_value=([mock_uri], None)) + self.context.get_page = Mock(return_value=mock_page) + + self.toolkit.search_messages_by_channel("general") + + # Verify search query includes channel filter + self.slack_service.search_messages.assert_called_once() + search_query = self.slack_service.search_messages.call_args[0][0] + assert "in:#general" in search_query + + def test_search_messages_by_person(self): + """Test search_messages_by_person tool.""" + mock_page = self._create_mock_message_page() + mock_uri = mock_page.uri + + self.slack_service.search_messages = Mock(return_value=([mock_uri], None)) + self.context.get_page 
= Mock(return_value=mock_page) + + self.toolkit.search_messages_by_person("@alice") + + # Verify search query includes person filter + self.slack_service.search_messages.assert_called_once() + search_query = self.slack_service.search_messages.call_args[0][0] + assert "from:@alice" in search_query + + def test_search_messages_by_date_range(self): + """Test search_messages_by_date_range tool.""" + mock_page = self._create_mock_message_page() + mock_uri = mock_page.uri + + self.slack_service.search_messages = Mock(return_value=([mock_uri], None)) + self.context.get_page = Mock(return_value=mock_page) + + self.toolkit.search_messages_by_date_range("2023-06-15", 7) + + # Verify search query includes date filters + self.slack_service.search_messages.assert_called_once() + search_query = self.slack_service.search_messages.call_args[0][0] + assert "after:2023-06-15" in search_query + assert "before:2023-06-22" in search_query + + def test_search_recent_messages(self): + """Test search_recent_messages tool.""" + mock_page = self._create_mock_message_page() + mock_uri = mock_page.uri + + self.slack_service.search_messages = Mock(return_value=([mock_uri], None)) + self.context.get_page = Mock(return_value=mock_page) + + self.toolkit.search_recent_messages(days=3) + + # Verify search was called + self.slack_service.search_messages.assert_called_once() + + def test_search_direct_messages(self): + """Test search_direct_messages tool.""" + mock_page = self._create_mock_message_page() + mock_uri = mock_page.uri + + self.slack_service.search_messages = Mock(return_value=([mock_uri], None)) + self.context.get_page = Mock(return_value=mock_page) + + self.toolkit.search_direct_messages(person="@alice") + + # Verify search query includes DM filters + self.slack_service.search_messages.assert_called_once() + search_query = self.slack_service.search_messages.call_args[0][0] + assert "in:@alice" in search_query + + def test_get_conversation_with_person(self): + """Test 
get_conversation_with_person tool.""" + mock_page = self._create_mock_message_page() + mock_uri = mock_page.uri + + self.slack_service.search_messages = Mock(return_value=([mock_uri], None)) + self.context.get_page = Mock(return_value=mock_page) + + self.toolkit.get_conversation_with_person("@alice") + + # Verify search query + self.slack_service.search_messages.assert_called_once() + search_query = self.slack_service.search_messages.call_args[0][0] + assert ( + search_query == "@alice" + ) # The method just passes the validated person identifier + + +class TestSlackServiceIntegration: + """Integration tests for SlackService components.""" + + def setup_method(self): + """Set up test environment.""" + clear_global_context() + + self.context = ServerContext(root="test-root", cache_url="sqlite:///:memory:") + self.context.handler = Mock() + + def mock_handler_decorator(page_type): + def decorator(func): + return func + + return decorator + + self.context.handler.side_effect = mock_handler_decorator + + set_global_context(self.context) + + self.mock_api_client = Mock() + self.service = SlackService(self.mock_api_client) + + def teardown_method(self): + """Clean up test environment.""" + clear_global_context() + + def test_message_page_thread_uri_matches_thread_page_uri(self): + """Test that message page thread_uri matches thread page URI.""" + # Setup message with thread + mock_messages = [ + { + "ts": "1234567890.001", + "user": "U123456", + "text": "Test message in thread", + "thread_ts": "1234567890.001", + } + ] + self.mock_api_client.get_conversation_history.return_value = ( + mock_messages, + None, + ) + + # Mock channel and user info + mock_channel_info = { + "id": "C1234567890", + "name": "general", + "is_channel": True, + "is_group": False, + "is_im": False, + "is_mpim": False, + } + self.mock_api_client.get_channel_info.return_value = mock_channel_info + self.mock_api_client.get_channel_members.return_value = ["U123456"] + + mock_user_info = { + "id": 
"U123456", + "name": "alice", + "profile": {"display_name": "Alice"}, + } + self.mock_api_client.get_user_info.return_value = mock_user_info + + # Mock thread replies + self.mock_api_client.get_thread_replies.return_value = mock_messages + + # Real page cache is now being used - no mocking needed + + # Create message page + message_id = "C1234567890_1234567890.001" + message_page = self.service.create_message_page(message_id) + + # Create thread page + thread_id = "C1234567890_1234567890.001" + thread_page = self.service.create_thread_page(thread_id) + + # Verify URIs match + assert message_page.thread_uri == thread_page.uri + + def test_channel_page_caching(self): + """Test that channel pages are properly cached.""" + # Mock channel info + mock_channel_info = { + "id": "C1234567890", + "name": "general", + "is_channel": True, + "is_group": False, + "is_im": False, + "is_mpim": False, + } + self.mock_api_client.get_channel_info.return_value = mock_channel_info + self.mock_api_client.get_channel_members.return_value = [] + + # Real page cache is now being used - test actual caching behavior + + # First call should create the page + page1 = self.service.get_channel_page("C1234567890") + + # Second call should return cached page + page2 = self.service.get_channel_page("C1234567890") + + assert page1.uri == page2.uri + assert page1.name == page2.name + assert page1.channel_type == page2.channel_type + + # API should only be called once + self.mock_api_client.get_channel_info.assert_called_once() + + def test_user_page_caching(self): + """Test that user pages are properly cached.""" + # Mock user info + mock_user_info = { + "id": "U123456", + "name": "alice", + "profile": {"display_name": "Alice"}, + } + self.mock_api_client.get_user_info.return_value = mock_user_info + + # First call should create the page + page1 = self.service.get_user_page("U123456") + + # Second call should return cached page + page2 = self.service.get_user_page("U123456") + + assert page1.uri == 
page2.uri + assert page1.name == page2.name + assert page1.user_id == page2.user_id + + self.mock_api_client.get_user_info.assert_called_once() diff --git a/tests/services/test_slack_utils.py b/tests/services/test_slack_utils.py new file mode 100644 index 0000000..f08a8c3 --- /dev/null +++ b/tests/services/test_slack_utils.py @@ -0,0 +1,566 @@ +"""Tests for SlackParser utility class.""" + +import pytest + +from pragweb.slack.utils import SlackParser + + +class TestSlackParser: + """Test suite for SlackParser utility class.""" + + def setup_method(self): + """Set up test environment.""" + self.parser = SlackParser() + + def test_encode_message_id(self): + """Test message ID encoding.""" + channel_id = "C1234567890" + message_ts = "1234567890.001" + + result = SlackParser.encode_message_id(channel_id, message_ts) + expected = "C1234567890_1234567890.001" + + assert result == expected + + def test_encode_message_id_with_underscores(self): + """Test message ID encoding with channel that contains underscores.""" + channel_id = "C_TEST_CHANNEL" + message_ts = "1234567890.001" + + result = SlackParser.encode_message_id(channel_id, message_ts) + expected = "C_TEST_CHANNEL_1234567890.001" + + assert result == expected + + def test_decode_message_id(self): + """Test message ID decoding.""" + message_id = "C1234567890_1234567890.001" + + channel_id, message_ts = SlackParser.decode_message_id(message_id) + + assert channel_id == "C1234567890" + assert message_ts == "1234567890.001" + + def test_decode_message_id_with_underscores(self): + """Test message ID decoding with channel that contains underscores.""" + message_id = "C_TEST_CHANNEL_1234567890.001" + + channel_id, message_ts = SlackParser.decode_message_id(message_id) + + assert channel_id == "C_TEST_CHANNEL" + assert message_ts == "1234567890.001" + + def test_decode_message_id_multiple_underscores(self): + """Test message ID decoding with multiple underscores in channel ID.""" + message_id = 
"C_VERY_LONG_CHANNEL_NAME_1234567890.001" + + channel_id, message_ts = SlackParser.decode_message_id(message_id) + + assert channel_id == "C_VERY_LONG_CHANNEL_NAME" + assert message_ts == "1234567890.001" + + def test_decode_message_id_invalid_format(self): + """Test message ID decoding with invalid format.""" + # No underscore + with pytest.raises(ValueError, match="Invalid message ID format"): + SlackParser.decode_message_id("C1234567890") + + # Empty string + with pytest.raises(ValueError, match="Invalid message ID format"): + SlackParser.decode_message_id("") + + def test_encode_thread_id(self): + """Test thread ID encoding.""" + channel_id = "C1234567890" + thread_ts = "1234567890.001" + + result = SlackParser.encode_thread_id(channel_id, thread_ts) + expected = "C1234567890_1234567890.001" + + assert result == expected + + def test_decode_thread_id(self): + """Test thread ID decoding.""" + thread_id = "C1234567890_1234567890.001" + + channel_id, thread_ts = SlackParser.decode_thread_id(thread_id) + + assert channel_id == "C1234567890" + assert thread_ts == "1234567890.001" + + def test_decode_thread_id_invalid_format(self): + """Test thread ID decoding with invalid format.""" + with pytest.raises(ValueError, match="Invalid thread ID format"): + SlackParser.decode_thread_id("invalidformat") + + def test_determine_channel_type_public_channel(self): + """Test channel type determination for public channel.""" + channel_info = { + "is_channel": True, + "is_group": False, + "is_im": False, + "is_mpim": False, + } + + result = SlackParser.determine_channel_type(channel_info) + assert result == "public_channel" + + def test_determine_channel_type_private_channel(self): + """Test channel type determination for private channel/group.""" + channel_info = { + "is_channel": False, + "is_group": True, + "is_im": False, + "is_mpim": False, + } + + result = SlackParser.determine_channel_type(channel_info) + assert result == "private_channel" + + def 
test_determine_channel_type_direct_message(self): + """Test channel type determination for direct message.""" + channel_info = { + "is_channel": False, + "is_group": False, + "is_im": True, + "is_mpim": False, + } + + result = SlackParser.determine_channel_type(channel_info) + assert result == "im" + + def test_determine_channel_type_group_dm(self): + """Test channel type determination for group DM.""" + channel_info = { + "is_channel": False, + "is_group": False, + "is_im": False, + "is_mpim": True, + } + + result = SlackParser.determine_channel_type(channel_info) + assert result == "mpim" + + def test_determine_channel_type_unknown(self): + """Test channel type determination for unknown type.""" + channel_info = { + "is_channel": False, + "is_group": False, + "is_im": False, + "is_mpim": False, + } + + result = SlackParser.determine_channel_type(channel_info) + assert result == "unknown" + + def test_get_user_display_name_with_display_name(self): + """Test user display name extraction with display_name.""" + user_info = { + "id": "U123456", + "name": "alice", + "real_name": "Alice Smith", + "profile": {"display_name": "Alice"}, + } + + result = SlackParser.get_user_display_name(user_info) + assert result == "Alice" + + def test_get_user_display_name_with_real_name(self): + """Test user display name extraction with real_name fallback.""" + user_info = { + "id": "U123456", + "name": "alice", + "real_name": "Alice Smith", + "profile": {}, # No display name + } + + result = SlackParser.get_user_display_name(user_info) + assert result == "Alice Smith" + + def test_get_user_display_name_with_username(self): + """Test user display name extraction with username fallback.""" + user_info = { + "id": "U123456", + "name": "alice", + # No real_name or display_name + } + + result = SlackParser.get_user_display_name(user_info) + assert result == "alice" + + def test_get_user_display_name_with_id_fallback(self): + """Test user display name extraction with ID fallback.""" + 
user_info = { + "id": "U123456", + # No name, real_name, or display_name + } + + result = SlackParser.get_user_display_name(user_info) + assert result == "U123456" + + def test_get_user_display_name_minimal(self): + """Test user display name extraction with minimal data.""" + user_info = {} + + result = SlackParser.get_user_display_name(user_info) + assert result == "unknown" + + def test_format_messages_for_content(self): + """Test message formatting for content.""" + messages = [ + { + "ts": "1234567890.001", + "user": "U123456", + "text": "Hello everyone!", + }, + { + "ts": "1234567890.002", + "user": "U789012", + "text": "Hi there!", + }, + { + "ts": "1234567890.003", + "user": "U456789", + "text": "How's everyone doing?", + }, + ] + + def mock_get_display_name(user_id: str) -> str: + names = { + "U123456": "Alice", + "U789012": "Bob", + "U456789": "Charlie", + } + return names.get(user_id, "Unknown") + + result = SlackParser.format_messages_for_content( + messages, mock_get_display_name + ) + + lines = result.split("\n") + assert len(lines) == 3 + + # Check first message - time format will depend on local timezone + assert "Alice: Hello everyone!" in lines[0] + assert "Bob: Hi there!" in lines[1] + assert "Charlie: How's everyone doing?" 
in lines[2] + assert all("[2009-02-13" in line for line in lines) # Check date part + + def test_format_messages_for_content_with_invalid_timestamp(self): + """Test message formatting with invalid timestamp.""" + messages = [ + { + "ts": "invalid_timestamp", + "user": "U123456", + "text": "Hello!", + }, + { + "ts": "", # Empty timestamp + "user": "U789012", + "text": "Hi!", + }, + { + # Missing timestamp + "user": "U456789", + "text": "Hey!", + }, + ] + + def mock_get_display_name(user_id: str) -> str: + return f"User_{user_id}" + + result = SlackParser.format_messages_for_content( + messages, mock_get_display_name + ) + + lines = result.split("\n") + assert len(lines) == 3 + + # Invalid timestamps should be preserved as-is or show as "unknown" + assert "invalid_timestamp" in lines[0] or "unknown" in lines[0] + assert "unknown" in lines[1] or "" in lines[1] + assert "unknown" in lines[2] + + def test_format_messages_for_content_empty_list(self): + """Test message formatting with empty message list.""" + messages = [] + + def mock_get_display_name(user_id: str) -> str: + return "User" + + result = SlackParser.format_messages_for_content( + messages, mock_get_display_name + ) + assert result == "" + + def test_format_messages_for_content_missing_fields(self): + """Test message formatting with missing fields.""" + messages = [ + { + "ts": "1234567890.001", + # Missing user + "text": "Hello!", + }, + { + "ts": "1234567890.002", + "user": "U123456", + # Missing text + }, + { + # Missing everything + }, + ] + + def mock_get_display_name(user_id: str) -> str: + if user_id == "unknown": + return "Unknown User" + return f"User_{user_id}" + + result = SlackParser.format_messages_for_content( + messages, mock_get_display_name + ) + + lines = result.split("\n") + assert len(lines) == 3 + + # Should handle missing fields gracefully + assert "Unknown User" in lines[0] # Missing user should default to "unknown" + assert "User_U123456" in lines[1] + assert lines[2] # Should still 
produce a line even with missing everything + + def test_validate_person_identifier_with_username(self): + """Test person identifier validation with @username.""" + result = SlackParser.validate_person_identifier("@alice") + assert result == "@alice" + + def test_validate_person_identifier_with_user_id(self): + """Test person identifier validation with user ID.""" + result = SlackParser.validate_person_identifier("U1234567890") + assert result == "<@U1234567890>" + + def test_validate_person_identifier_empty(self): + """Test person identifier validation with empty string.""" + with pytest.raises(ValueError, match="Person identifier cannot be empty"): + SlackParser.validate_person_identifier("") + + def test_validate_person_identifier_invalid_format(self): + """Test person identifier validation with invalid format.""" + # Display name (not supported) + with pytest.raises(ValueError, match="Person identifier .* is invalid"): + SlackParser.validate_person_identifier("Alice Smith") + + # Email (not supported) + with pytest.raises(ValueError, match="Person identifier .* is invalid"): + SlackParser.validate_person_identifier("alice@example.com") + + # Short user ID-like string + with pytest.raises(ValueError, match="Person identifier .* is invalid"): + SlackParser.validate_person_identifier("U123") + + def test_validate_person_identifier_edge_cases(self): + """Test person identifier validation edge cases.""" + # Just @ symbol should be rejected + with pytest.raises(ValueError, match="Username cannot be empty after @"): + SlackParser.validate_person_identifier("@") + + # User ID that doesn't start with U + with pytest.raises(ValueError, match="Person identifier .* is invalid"): + SlackParser.validate_person_identifier("T1234567890") # Team ID format + + # Invalid user ID - too long + with pytest.raises(ValueError, match="Person identifier .* is invalid"): + SlackParser.validate_person_identifier("U123456789012345") # Too long + + # Invalid user ID - too short + with 
pytest.raises(ValueError, match="Person identifier .* is invalid"): + SlackParser.validate_person_identifier("U123") # Too short + + # Valid edge case usernames + result = SlackParser.validate_person_identifier("@user.name") + assert result == "@user.name" + + result = SlackParser.validate_person_identifier("@user_name") + assert result == "@user_name" + + result = SlackParser.validate_person_identifier("@user-name") + assert result == "@user-name" + + # Invalid username with special characters + with pytest.raises(ValueError, match="Invalid username format"): + SlackParser.validate_person_identifier("@user@name") + + # Whitespace handling + result = SlackParser.validate_person_identifier(" @alice ") + assert result == "@alice" + + +class TestSlackParserIntegration: + """Integration tests for SlackParser methods working together.""" + + def test_message_id_roundtrip(self): + """Test that message ID encoding/decoding is reversible.""" + original_channel = "C1234567890" + original_ts = "1234567890.001" + + # Encode then decode + encoded = SlackParser.encode_message_id(original_channel, original_ts) + decoded_channel, decoded_ts = SlackParser.decode_message_id(encoded) + + assert decoded_channel == original_channel + assert decoded_ts == original_ts + + def test_thread_id_roundtrip(self): + """Test that thread ID encoding/decoding is reversible.""" + original_channel = "C1234567890" + original_ts = "1234567890.001" + + # Encode then decode + encoded = SlackParser.encode_thread_id(original_channel, original_ts) + decoded_channel, decoded_ts = SlackParser.decode_thread_id(encoded) + + assert decoded_channel == original_channel + assert decoded_ts == original_ts + + def test_message_id_vs_thread_id_encoding(self): + """Test that message ID and thread ID encoding produce same results.""" + channel = "C1234567890" + ts = "1234567890.001" + + message_id = SlackParser.encode_message_id(channel, ts) + thread_id = SlackParser.encode_thread_id(channel, ts) + + # Should be 
identical since they use the same encoding scheme + assert message_id == thread_id + + def test_complex_channel_name_handling(self): + """Test handling of complex channel names with special characters.""" + complex_channel = "C_TEAM_PROJ_123_TEST" + ts = "1234567890.001" + + # Test message ID + encoded_msg = SlackParser.encode_message_id(complex_channel, ts) + decoded_channel, decoded_ts = SlackParser.decode_message_id(encoded_msg) + + assert decoded_channel == complex_channel + assert decoded_ts == ts + + # Test thread ID + encoded_thread = SlackParser.encode_thread_id(complex_channel, ts) + decoded_channel, decoded_ts = SlackParser.decode_thread_id(encoded_thread) + + assert decoded_channel == complex_channel + assert decoded_ts == ts + + def test_user_display_name_with_different_data_structures(self): + """Test user display name extraction with various data structures.""" + test_cases = [ + # Complete profile + { + "input": { + "id": "U123", + "name": "alice", + "real_name": "Alice Smith", + "profile": {"display_name": "Alice S."}, + }, + "expected": "Alice S.", + }, + # Missing profile.display_name + { + "input": { + "id": "U123", + "name": "bob", + "real_name": "Bob Jones", + "profile": {}, + }, + "expected": "Bob Jones", + }, + # Missing profile entirely + { + "input": { + "id": "U123", + "name": "charlie", + "real_name": "Charlie Brown", + }, + "expected": "Charlie Brown", + }, + # Only username + { + "input": { + "id": "U123", + "name": "diana", + }, + "expected": "diana", + }, + # Only ID + { + "input": { + "id": "U123", + }, + "expected": "U123", + }, + # Empty dict + { + "input": {}, + "expected": "unknown", + }, + ] + + for test_case in test_cases: + result = SlackParser.get_user_display_name(test_case["input"]) + assert ( + result == test_case["expected"] + ), f"Failed for input: {test_case['input']}" + + +class TestSlackParserErrorHandling: + """Test error handling and edge cases for SlackParser.""" + + def test_format_messages_with_none_values(self): 
+ """Test message formatting with None values.""" + messages = [ + { + "ts": None, + "user": None, + "text": None, + }, + { + "ts": "1234567890.001", + "user": None, + "text": "Valid message", + }, + ] + + def mock_get_display_name(user_id: str) -> str: + if user_id is None or user_id == "unknown": + return "Unknown" + return f"User_{user_id}" + + # Should not raise exception + result = SlackParser.format_messages_for_content( + messages, mock_get_display_name + ) + assert isinstance(result, str) + + def test_channel_type_with_none_values(self): + """Test channel type determination with None values.""" + channel_info = { + "is_channel": None, + "is_group": None, + "is_im": None, + "is_mpim": None, + } + + result = SlackParser.determine_channel_type(channel_info) + assert result == "unknown" + + def test_user_display_name_with_none_profile(self): + """Test user display name with None profile.""" + user_info = { + "id": "U123", + "name": "alice", + "profile": None, + } + + # Current implementation doesn't handle None profile gracefully + with pytest.raises(AttributeError): + SlackParser.get_user_display_name(user_info)