Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
307 changes: 307 additions & 0 deletions src/scribae/idea.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
from __future__ import annotations

import asyncio
import json
import re
import textwrap
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path

from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic_ai import Agent, NativeOutput, UnexpectedModelBehavior
from pydantic_ai.settings import ModelSettings

from .io_utils import NoteDetails, load_note
from .llm import DEFAULT_MODEL_NAME, LLM_OUTPUT_RETRIES, OpenAISettings, make_model
from .project import ProjectConfig

DEFAULT_IDEA_MODEL = DEFAULT_MODEL_NAME
IDEA_TIMEOUT_SECONDS = 300.0


class IdeaError(Exception):
"""Raised when ideas cannot be generated."""

exit_code = 1

def __init__(self, message: str, *, exit_code: int | None = None) -> None:
super().__init__(message)
if exit_code is not None:
self.exit_code = exit_code


class IdeaValidationError(IdeaError):
exit_code = 2


class IdeaFileError(IdeaError):
exit_code = 3


class IdeaLLMError(IdeaError):
exit_code = 4


class Idea(BaseModel):
"""Structured representation of a generated idea."""

model_config = ConfigDict(extra="forbid")

title: str = Field(..., min_length=3)
description: str = Field(..., min_length=10)
why: str = Field(..., min_length=5)

@field_validator("title", "description", "why", mode="before")
@classmethod
def _strip_text(cls, value: str) -> str:
if not isinstance(value, str):
raise TypeError("value must be a string")
return value.strip()


class IdeaList(BaseModel):
"""Container for a collection of ideas."""

model_config = ConfigDict(extra="forbid")

ideas: list[Idea] = Field(default_factory=list, min_length=1)


@dataclass(frozen=True)
class IdeaPromptBundle:
"""Container for the system and user prompts."""

system_prompt: str
user_prompt: str


@dataclass(frozen=True)
class IdeaContext:
"""Artifacts required to generate a list of ideas."""

note: NoteDetails
project: ProjectConfig
prompts: IdeaPromptBundle


Reporter = Callable[[str], None] | None

IDEA_SYSTEM_PROMPT = textwrap.dedent(
"""
You are a creative strategist who proposes concise, audience-aware content ideas.
Output must be a pure JSON object with an "ideas" array, no prose or Markdown.
Each idea object must include exactly these fields: "title", "description", "why".
Keep titles concise and avoid numbered prefixes.
"""
).strip()


def prepare_context(
note_path: Path,
*,
project: ProjectConfig,
max_chars: int,
reporter: Reporter = None,
) -> IdeaContext:
"""Load note data and assemble prompt context."""

if max_chars <= 0:
raise IdeaValidationError("--max-chars must be greater than zero.")

try:
note = load_note(note_path, max_chars=max_chars)
except FileNotFoundError as exc:
raise IdeaFileError(f"Note file not found: {note_path}") from exc
except ValueError as exc:
raise IdeaFileError(str(exc)) from exc
except OSError as exc: # pragma: no cover - surfaced by CLI
raise IdeaFileError(f"Unable to read note: {exc}") from exc

_report(reporter, f"Loaded note '{note.title}' from {note.path}")

prompts = IdeaPromptBundle(
system_prompt=IDEA_SYSTEM_PROMPT,
user_prompt=_build_user_prompt(project=project, note_title=note.title, note_content=note.body),
)
_report(reporter, "Prepared idea-generation prompt.")

return IdeaContext(note=note, project=project, prompts=prompts)


def generate_ideas(
context: IdeaContext,
*,
model_name: str,
temperature: float,
reporter: Reporter = None,
settings: OpenAISettings | None = None,
agent: Agent[None, IdeaList] | None = None,
timeout_seconds: float = IDEA_TIMEOUT_SECONDS,
) -> IdeaList:
"""Run the LLM call and return validated ideas."""

resolved_settings = settings or OpenAISettings.from_env()
llm_agent: Agent[None, IdeaList] = (
_create_agent(model_name, resolved_settings, temperature=temperature) if agent is None else agent
)

_report(reporter, f"Calling model '{model_name}' via {resolved_settings.base_url}")

try:
ideas = _invoke_agent(llm_agent, context.prompts.user_prompt, timeout_seconds=timeout_seconds)
except UnexpectedModelBehavior as exc:
raise IdeaValidationError(
"LLM response never satisfied the idea list schema, giving up after repeated retries."
) from exc
except IdeaLLMError:
raise
except Exception as exc: # pragma: no cover - surfaced to CLI
raise IdeaLLMError(f"LLM request failed: {exc}") from exc

_report(reporter, "LLM call complete, ideas validated.")
return ideas


def render_json(result: IdeaList) -> str:
"""Return the ideas as a JSON string."""

return json.dumps(result.model_dump(), indent=2, ensure_ascii=False)


def save_prompt_artifacts(
context: IdeaContext,
*,
destination: Path,
project_label: str,
timestamp: str | None = None,
) -> tuple[Path, Path]:
"""Persist the system prompt and truncated note for debugging."""

destination.mkdir(parents=True, exist_ok=True)
stamp = timestamp or _current_timestamp()
slug = _slugify(project_label or "default") or "default"

prompt_path = destination / f"{stamp}-{slug}-ideas.prompt.txt"
note_path = destination / f"{stamp}-note.txt"

prompt_payload = (
f"SYSTEM PROMPT:\n{context.prompts.system_prompt}\n\nUSER PROMPT:\n{context.prompts.user_prompt}\n"
)
prompt_path.write_text(prompt_payload, encoding="utf-8")
note_path.write_text(context.note.body, encoding="utf-8")

return prompt_path, note_path


def _build_user_prompt(*, project: ProjectConfig, note_title: str, note_content: str) -> str:
keywords = ", ".join(project["keywords"]) if project["keywords"] else "none"
allowed_tags = ", ".join(project["allowed_tags"] or []) if project["allowed_tags"] else "any"

template = textwrap.dedent(
"""
[PROJECT CONTEXT]
Site: {site_name} ({domain})
Audience: {audience}
Tone: {tone}
FocusKeywords: {keywords}
AllowedTags: {allowed_tags}
Language: {language}

[TASK]
Propose 5–8 content ideas grounded in the note. Avoid generic listicles or duplicative angles.
Each idea must include:
- title: 5–12 words capturing the core hook.
- description: 2–3 sentences describing the article or asset.
- why: 1–2 sentences explaining why this idea fits the audience and project goals.
Respond with a JSON object containing an "ideas" array of idea objects, nothing else.

[NOTE TITLE]
{note_title}

[NOTE CONTENT]
{note_content}

JSON only. The root object must contain an "ideas" array with at least 5 entries.
"""
).strip()

return template.format(
site_name=project["site_name"],
domain=project["domain"],
audience=project["audience"],
tone=project["tone"],
keywords=keywords,
allowed_tags=allowed_tags,
language=project["language"],
note_title=note_title.strip(),
note_content=note_content.strip(),
)


def _create_agent(model_name: str, settings: OpenAISettings, *, temperature: float) -> Agent[None, IdeaList]:
"""Instantiate the Pydantic AI agent for generating ideas."""

settings.configure_environment()
model_settings = ModelSettings(temperature=temperature)
model = make_model(model_name, model_settings=model_settings, settings=settings)
return Agent[None, IdeaList](
model=model,
output_type=NativeOutput(IdeaList, name="IdeaList", strict=True),
instructions=IDEA_SYSTEM_PROMPT,
output_retries=LLM_OUTPUT_RETRIES,
)


def _invoke_agent(agent: Agent[None, IdeaList], prompt: str, *, timeout_seconds: float) -> IdeaList:
"""Run the agent with a timeout using asyncio."""

async def _call() -> IdeaList:
run = await agent.run(prompt)
output = getattr(run, "output", None)
if isinstance(output, IdeaList):
return output
if isinstance(output, BaseModel):
return IdeaList.model_validate(output.model_dump())
if isinstance(output, list):
return IdeaList.model_validate({"ideas": output})
if isinstance(output, dict):
return IdeaList.model_validate(output)
raise TypeError("LLM output is not an IdeaList instance")

return asyncio.run(asyncio.wait_for(_call(), timeout_seconds))


def _current_timestamp() -> str:
return datetime.now().strftime("%Y%m%d-%H%M%S")


def _slugify(value: str) -> str:
lowered = value.lower()
return re.sub(r"[^a-z0-9]+", "-", lowered).strip("-")


def _report(reporter: Reporter, message: str) -> None:
"""Send verbose output when enabled."""

if reporter:
reporter(message)


__all__ = [
"DEFAULT_IDEA_MODEL",
"Idea",
"IdeaContext",
"IdeaError",
"IdeaFileError",
"IdeaLLMError",
"IdeaList",
"IdeaPromptBundle",
"IdeaValidationError",
"OpenAISettings",
"generate_ideas",
"prepare_context",
"render_json",
"save_prompt_artifacts",
]
Loading