diff --git a/src/api/routers/community.py b/src/api/routers/community.py index b6e7064..37040b1 100644 --- a/src/api/routers/community.py +++ b/src/api/routers/community.py @@ -34,7 +34,7 @@ from src.assistants.registry import AssistantInfo from src.core.config.community import WidgetConfig from src.core.services.litellm_llm import create_openrouter_llm -from src.knowledge.search import FAQResult, list_faq_entries +from src.knowledge.search import FAQResult, get_citation_stats, list_faq_entries from src.metrics.cost import COST_BLOCK_THRESHOLD, COST_WARN_THRESHOLD, MODEL_PRICING, estimate_cost from src.metrics.db import ( RequestLogEntry, @@ -229,6 +229,23 @@ class FAQFeedResponse(BaseModel): entries: list[FAQEntryResponse] = Field(default_factory=list, description="FAQ entries") +class CitationsFeedResponse(BaseModel): + """Public citation dashboard data for a community's canonical papers.""" + + community_id: str = Field(..., description="Community identifier") + total: int = Field(..., description="Total citing papers with a recorded canonical link") + per_year: dict[str, int] = Field( + default_factory=dict, description="Citing-paper count per year across all papers" + ) + by_paper: dict[str, dict[str, int]] = Field( + default_factory=dict, + description="Stacked breakdown: canonical DOI -> year -> citing-paper count", + ) + canonical_dois: list[str] = Field( + default_factory=list, description="Canonical DOIs tracked for this community" + ) + + # Matches bare email addresses so they can be stripped from the public feed. _EMAIL_PATTERN = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}") @@ -1621,6 +1638,50 @@ async def community_faq( entries=[_faq_result_to_response(e) for e in entries], ) + @router.get("/citations", response_model=CitationsFeedResponse) + async def community_citations(response: Response) -> CitationsFeedResponse: + """Public, read-only citation dashboard for this community. + + Returns per-year counts of papers citing the community's canonical + works, plus a stacked breakdown keyed by the cited DOI (the shape + behind a citations-per-year chart). Disabled by default; a community + opts in via ``public_feeds.citations: true`` in its config. + """ + config = info.community_config + if config is None or config.public_feeds is None or not config.public_feeds.citations: + raise HTTPException( + status_code=404, + detail="Public citations feed is not enabled for this community.", + ) + + try: + stats = get_citation_stats(project=community_id) + except sqlite3.Error: + logger.exception("Failed to query citations for community %s", community_id) + raise HTTPException( + status_code=503, + detail="Knowledge database is temporarily unavailable.", + ) + except Exception: + logger.exception( + "Unexpected error serving citations feed for community %s", community_id + ) + raise HTTPException( + status_code=500, + detail="An unexpected error occurred while building the citations feed.", + ) + + canonical_dois = list(config.citations.dois) if config.citations else [] + + response.headers["Cache-Control"] = "public, max-age=3600" + return CitationsFeedResponse( + community_id=community_id, + total=stats.total, + per_year=stats.per_year, + by_paper=stats.by_paper, + canonical_dois=canonical_dois, + ) + return router diff --git a/src/knowledge/db.py b/src/knowledge/db.py index 5c9166d..ba2dfb6 100644 --- a/src/knowledge/db.py +++ b/src/knowledge/db.py @@ -132,6 +132,9 @@ def active_mirror_context(mirror_id: str) -> Iterator[None]: url TEXT NOT NULL, created_at TEXT, synced_at TEXT NOT NULL, + -- Canonical DOI this paper cites, when discovered via citation sync. + -- NULL for papers found through keyword search rather than a citation link. + cites_doi TEXT, UNIQUE(source, external_id) ); @@ -409,6 +412,8 @@ def active_mirror_context(mirror_id: str) -> Iterator[None]: CREATE INDEX IF NOT EXISTS idx_github_items_status ON github_items(status); CREATE INDEX IF NOT EXISTS idx_github_items_type ON github_items(item_type); CREATE INDEX IF NOT EXISTS idx_papers_source ON papers(source); +-- idx_papers_cites_doi is created in _migrate_db, after the cites_doi column +-- is ensured, so init_db stays safe on databases predating that column. CREATE INDEX IF NOT EXISTS idx_docstrings_repo ON docstrings(repo); CREATE INDEX IF NOT EXISTS idx_docstrings_language ON docstrings(language); CREATE INDEX IF NOT EXISTS idx_messages_list ON mailing_list_messages(list_name); @@ -507,6 +512,28 @@ def _migrate_db(conn: sqlite3.Connection) -> None: # Table doesn't exist yet - this is fine, schema will create it logger.debug("Docstrings table not found during migration (will be created): %s", e) + # Migration: Add cites_doi column to papers table (added 2026-06-09). + # The index lives here (not in SCHEMA_SQL) so executescript never references + # cites_doi on a database created before the column existed. + try: + cursor = conn.execute("PRAGMA table_info(papers)") + columns = [row[1] for row in cursor.fetchall()] + except sqlite3.OperationalError as e: + # Only the PRAGMA is guarded here: a missing papers table is fine since + # SCHEMA_SQL creates it. DDL errors below (locked DB, I/O fault) must + # propagate rather than be swallowed and leave the table un-indexed. + logger.debug("Papers table not found during migration (will be created): %s", e) + columns = [] + + if columns: # papers table exists; migrate it in place + if "cites_doi" not in columns: + logger.info("Migrating papers table: adding cites_doi column") + conn.execute("ALTER TABLE papers ADD COLUMN cites_doi TEXT") + logger.info("Migration complete: cites_doi column added to papers") + # Ensure the index exists for both new and migrated databases. + conn.execute("CREATE INDEX IF NOT EXISTS idx_papers_cites_doi ON papers(cites_doi)") + conn.commit() + def init_db(project: str = "hed") -> None: """Initialize database schema for a project. @@ -586,6 +613,7 @@ def upsert_paper( first_message: str | None, url: str, created_at: str | None, + cites_doi: str | None = None, ) -> None: """Insert or update a paper. @@ -597,6 +625,14 @@ def upsert_paper( first_message: Abstract (limited to ~2000 chars) url: URL to the paper (DOI or source URL) created_at: Publication date (ISO 8601 or year string) + cites_doi: Canonical DOI this paper cites, when known from a citation + sync. ``None`` for keyword-search results. On conflict the first + recorded link is kept (COALESCE), so a later keyword sync passing + ``None`` never erases an existing citation link, and a re-sync + backfills the link onto rows stored before this column existed. + A single column holds one link: a paper citing two tracked DOIs is + attributed to whichever was synced first (it is still counted once + in the per-year total, only its by-paper bucket is approximate). """ # Limit first_message size if first_message and len(first_message) > 2000: @@ -605,14 +641,15 @@ def upsert_paper( conn.execute( """ INSERT INTO papers (source, external_id, title, first_message, - status, url, created_at, synced_at) - VALUES (?, ?, ?, ?, 'published', ?, ?, ?) + status, url, created_at, synced_at, cites_doi) + VALUES (?, ?, ?, ?, 'published', ?, ?, ?, ?) ON CONFLICT(source, external_id) DO UPDATE SET title=excluded.title, first_message=excluded.first_message, - synced_at=excluded.synced_at + synced_at=excluded.synced_at, + cites_doi=COALESCE(papers.cites_doi, excluded.cites_doi) """, - (source, external_id, title, first_message, url, created_at, _now_iso()), + (source, external_id, title, first_message, url, created_at, _now_iso(), cites_doi), ) diff --git a/src/knowledge/papers_sync.py b/src/knowledge/papers_sync.py index a83806b..f185e27 100644 --- a/src/knowledge/papers_sync.py +++ b/src/knowledge/papers_sync.py @@ -158,6 +158,7 @@ def _store_papers( project: str, *, force_source: str | None = None, + cites_doi: str | None = None, ) -> dict[str, int]: """Upsert opencite papers into the knowledge DB, returning counts by source. @@ -167,6 +168,8 @@ def _store_papers( force_source: When set (a single-source sync), record this OSA source label using its native identifier; falls back to the priority mapping if that identifier is missing. + cites_doi: Canonical DOI these papers cite, recorded on each row when + storing the results of a citation sync. ``None`` for keyword search. """ counts: dict[str, int] = {} with get_connection(project) as conn: @@ -193,6 +196,7 @@ def _store_papers( first_message=paper.abstract or None, url=_paper_url(paper), created_at=paper.publication_date or (str(paper.year) if paper.year else None), + cites_doi=cites_doi, ) counts[source] = counts.get(source, 0) + 1 conn.commit() @@ -420,7 +424,7 @@ def sync_citing_papers( total = 0 for doi, papers in cited: try: - counts = _store_papers(papers, project) + counts = _store_papers(papers, project, cites_doi=doi) count = sum(counts.values()) update_sync_metadata("papers", f"citing_{doi}", count, project) logger.info("Synced %d papers citing %s", count, doi) diff --git a/src/knowledge/search.py b/src/knowledge/search.py index 2b3fdb5..c8d0b7a 100644 --- a/src/knowledge/search.py +++ b/src/knowledge/search.py @@ -376,6 +376,77 @@ def search_github_items( return results +@dataclass +class CitationStats: + """Aggregated citation counts for a community's canonical papers.""" + + total: int + """Total citing papers with a recorded canonical link and a valid year.""" + + per_year: dict[str, int] + """Citing-paper count per publication year, summed across canonical DOIs.""" + + by_paper: dict[str, dict[str, int]] + """Per canonical DOI: a mapping of publication year to citing-paper count.""" + + +def get_citation_stats(project: str = "eeglab") -> CitationStats: + """Aggregate citation counts for the public citations dashboard. + + Counts papers that cite a community's canonical DOIs (``papers.cites_doi`` + is set), grouped by the citing paper's publication year. The year is the + leading four digits of ``created_at`` (ISO date or bare year); rows whose + ``created_at`` is missing or not a four-digit year are skipped so a bad + date never lands in a bogus year bucket. + + Args: + project: Community ID for database isolation. Defaults to 'eeglab'. + + Returns: + CitationStats with the overall ``total``, ``per_year`` totals, and the + stacked ``by_paper`` breakdown (canonical DOI -> year -> count). Years + are sorted ascending in every mapping. + """ + sql = """ + SELECT cites_doi, substr(created_at, 1, 4) AS yr, COUNT(*) AS cnt + FROM papers + WHERE cites_doi IS NOT NULL + AND created_at IS NOT NULL + AND substr(created_at, 1, 4) GLOB '[0-9][0-9][0-9][0-9]' + GROUP BY cites_doi, yr + """ + + per_year: dict[str, int] = {} + by_paper: dict[str, dict[str, int]] = {} + total = 0 + try: + with get_connection(project) as conn: + for row in conn.execute(sql): + doi = row["cites_doi"] + year = row["yr"] + count = row["cnt"] + per_year[year] = per_year.get(year, 0) + count + by_paper.setdefault(doi, {})[year] = count + total += count + except sqlite3.OperationalError as e: + logger.error( + "Database operational error computing citation stats: %s", + e, + exc_info=True, + extra={"project": project}, + ) + raise + except sqlite3.Error as e: + logger.warning("Database error computing citation stats (project=%s): %s", project, e) + raise + + return CitationStats( + total=total, + per_year=dict(sorted(per_year.items())), + by_paper={doi: dict(sorted(years.items())) for doi, years in by_paper.items()}, + ) + + def search_papers( query: str, project: str = "hed", diff --git a/tests/test_api/test_citations_feed.py b/tests/test_api/test_citations_feed.py new file mode 100644 index 0000000..bbf0e6b --- /dev/null +++ b/tests/test_api/test_citations_feed.py @@ -0,0 +1,204 @@ +"""Tests for the public citations feed endpoint: GET /{community_id}/citations. + +Uses a real registered community, a temporary SQLite knowledge database with +citing papers, and the config gate toggled per test. No business logic is +mocked except in TestCitationsFeedErrors, where get_citation_stats is patched +at the router call boundary to inject DB/unexpected errors and verify the +503/500 responses. +""" + +import sqlite3 +from collections.abc import Iterator +from pathlib import Path +from unittest.mock import patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from src.api.routers.community import create_community_router +from src.assistants import discover_assistants, registry +from src.core.config.community import PublicFeedsConfig +from src.knowledge.db import get_connection, init_db, upsert_paper + +COMMUNITY_ID = "eeglab" +DOI_A = "10.1016/j.jneumeth.2003.10.009" +DOI_B = "10.1016/j.neuroimage.2019.05.026" + +discover_assistants() + + +@pytest.fixture +def citations_db(tmp_path: Path) -> Iterator[Path]: + """Temp knowledge DB with citing papers across two canonical DOIs.""" + db_path = tmp_path / "knowledge" / "test.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db(COMMUNITY_ID) + with get_connection(COMMUNITY_ID) as conn: + rows = [ + ("a1", "2019-05-01", DOI_A), + ("a2", "2019-11-20", DOI_A), + ("a3", "2020", DOI_A), + ("b1", "2020-02-02", DOI_B), + ("k1", "2021", None), # keyword-only, excluded from stats + ] + for external_id, created_at, cites_doi in rows: + upsert_paper( + conn, + source="openalex", + external_id=external_id, + title=f"Paper {external_id}", + first_message=None, + url=f"https://doi.org/10.test/{external_id}", + created_at=created_at, + cites_doi=cites_doi, + ) + conn.commit() + yield db_path + + +@pytest.fixture +def citations_enabled() -> Iterator[None]: + info = registry.get(COMMUNITY_ID) + assert info is not None and info.community_config is not None + original = info.community_config.public_feeds + info.community_config.public_feeds = PublicFeedsConfig(citations=True) + try: + yield + finally: + info.community_config.public_feeds = original + + +@pytest.fixture +def citations_disabled_none() -> Iterator[None]: + info = registry.get(COMMUNITY_ID) + assert info is not None and info.community_config is not None + original = info.community_config.public_feeds + info.community_config.public_feeds = None + try: + yield + finally: + info.community_config.public_feeds = original + + +@pytest.fixture +def citations_flag_false() -> Iterator[None]: + info = registry.get(COMMUNITY_ID) + assert info is not None and info.community_config is not None + original = info.community_config.public_feeds + info.community_config.public_feeds = PublicFeedsConfig(citations=False) + try: + yield + finally: + info.community_config.public_feeds = original + + +@pytest.fixture +def citations_enabled_no_config() -> Iterator[None]: + """Feed enabled but the community has no citations config block.""" + info = registry.get(COMMUNITY_ID) + assert info is not None and info.community_config is not None + orig_feeds = info.community_config.public_feeds + orig_citations = info.community_config.citations + info.community_config.public_feeds = PublicFeedsConfig(citations=True) + info.community_config.citations = None + try: + yield + finally: + info.community_config.public_feeds = orig_feeds + info.community_config.citations = orig_citations + + +@pytest.fixture +def client() -> TestClient: + app = FastAPI() + app.include_router(create_community_router(COMMUNITY_ID)) + return TestClient(app) + + +class TestCitationsFeedGate: + """The endpoint is opt-in via public_feeds.citations.""" + + @pytest.mark.usefixtures("citations_disabled_none") + def test_disabled_when_public_feeds_none(self, client, citations_db): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + resp = client.get(f"/{COMMUNITY_ID}/citations") + assert resp.status_code == 404 + + @pytest.mark.usefixtures("citations_flag_false") + def test_disabled_when_flag_false(self, client, citations_db): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + resp = client.get(f"/{COMMUNITY_ID}/citations") + assert resp.status_code == 404 + + @pytest.mark.usefixtures("citations_enabled") + def test_enabled_returns_200(self, client, citations_db): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + resp = client.get(f"/{COMMUNITY_ID}/citations") + assert resp.status_code == 200 + + +@pytest.mark.usefixtures("citations_enabled") +class TestCitationsFeedContent: + def test_total_and_per_year(self, client, citations_db): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + resp = client.get(f"/{COMMUNITY_ID}/citations") + body = resp.json() + assert body["community_id"] == COMMUNITY_ID + assert body["total"] == 4 # a1,a2,a3,b1 ; k1 unlinked excluded + assert body["per_year"] == {"2019": 2, "2020": 2} + + def test_by_paper_stacked_breakdown(self, client, citations_db): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + resp = client.get(f"/{COMMUNITY_ID}/citations") + by_paper = resp.json()["by_paper"] + assert by_paper == { + DOI_A: {"2019": 2, "2020": 1}, + DOI_B: {"2020": 1}, + } + + def test_canonical_dois_from_config(self, client, citations_db): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + resp = client.get(f"/{COMMUNITY_ID}/citations") + canonical = resp.json()["canonical_dois"] + # eeglab config tracks these canonical DOIs. + assert DOI_A in canonical + assert DOI_B in canonical + + def test_cache_control_header(self, client, citations_db): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + resp = client.get(f"/{COMMUNITY_ID}/citations") + assert resp.headers["Cache-Control"] == "public, max-age=3600" + + +class TestCitationsFeedNoConfig: + """Feed enabled for a community without a citations config block.""" + + @pytest.mark.usefixtures("citations_enabled_no_config") + def test_canonical_dois_empty_when_no_citations_config(self, client, citations_db): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + resp = client.get(f"/{COMMUNITY_ID}/citations") + body = resp.json() + assert resp.status_code == 200 + assert body["canonical_dois"] == [] + # Stats still come from the DB regardless of config presence. + assert body["total"] == 4 + + +@pytest.mark.usefixtures("citations_enabled") +class TestCitationsFeedErrors: + def test_db_error_returns_503(self, client): + with patch( + "src.api.routers.community.get_citation_stats", + side_effect=sqlite3.OperationalError("db is locked"), + ): + resp = client.get(f"/{COMMUNITY_ID}/citations") + assert resp.status_code == 503 + + def test_unexpected_error_returns_500(self, client): + with patch( + "src.api.routers.community.get_citation_stats", + side_effect=RuntimeError("boom"), + ): + resp = client.get(f"/{COMMUNITY_ID}/citations") + assert resp.status_code == 500 diff --git a/tests/test_knowledge/test_citation_stats.py b/tests/test_knowledge/test_citation_stats.py new file mode 100644 index 0000000..4d828cb --- /dev/null +++ b/tests/test_knowledge/test_citation_stats.py @@ -0,0 +1,179 @@ +"""Tests for citation stats aggregation and the cites_doi linkage column. + +Uses a real temporary SQLite database (only the DB path is redirected); no +business logic is mocked. +""" + +from pathlib import Path +from unittest.mock import patch + +import pytest + +from src.knowledge.db import get_connection, init_db, upsert_paper +from src.knowledge.search import CitationStats, get_citation_stats + +DOI_A = "10.1016/j.jneumeth.2003.10.009" +DOI_B = "10.1016/j.neuroimage.2019.05.026" + + +def _add_paper(conn, external_id, *, created_at, cites_doi=None, source="openalex"): + upsert_paper( + conn, + source=source, + external_id=external_id, + title=f"Citing paper {external_id}", + first_message=None, + url=f"https://doi.org/10.test/{external_id}", + created_at=created_at, + cites_doi=cites_doi, + ) + + +@pytest.fixture +def citations_db(tmp_path: Path): + """Temp DB with citing papers across two canonical DOIs and several years.""" + db_path = tmp_path / "knowledge" / "test.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db() + with get_connection() as conn: + # DOI_A: 2 in 2019, 1 in 2020 + _add_paper(conn, "a1", created_at="2019-05-01", cites_doi=DOI_A) + _add_paper(conn, "a2", created_at="2019-11-20", cites_doi=DOI_A) + _add_paper(conn, "a3", created_at="2020", cites_doi=DOI_A) + # DOI_B: 1 in 2020, 1 in 2021 + _add_paper(conn, "b1", created_at="2020-02-02", cites_doi=DOI_B) + _add_paper(conn, "b2", created_at="2021-07-07", cites_doi=DOI_B) + # Keyword-search paper (no citation link) - excluded from stats + _add_paper(conn, "k1", created_at="2022", cites_doi=None) + # Citing paper with an unusable date - excluded from year buckets + _add_paper(conn, "x1", created_at="", cites_doi=DOI_A) + _add_paper(conn, "x2", created_at=None, cites_doi=DOI_B) + conn.commit() + yield db_path + + +class TestGetCitationStats: + def test_returns_citation_stats_object(self, citations_db: Path): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + stats = get_citation_stats(project="eeglab") + assert isinstance(stats, CitationStats) + + def test_total_excludes_unlinked_and_undated(self, citations_db: Path): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + stats = get_citation_stats(project="eeglab") + # 5 linked papers with valid years (a1,a2,a3,b1,b2); k1 unlinked, + # x1/x2 undated are excluded. + assert stats.total == 5 + + def test_per_year_aggregates_across_dois(self, citations_db: Path): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + stats = get_citation_stats(project="eeglab") + assert stats.per_year == {"2019": 2, "2020": 2, "2021": 1} + + def test_per_year_is_sorted_ascending(self, citations_db: Path): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + stats = get_citation_stats(project="eeglab") + assert list(stats.per_year.keys()) == sorted(stats.per_year.keys()) + + def test_by_paper_stacked_breakdown(self, citations_db: Path): + with patch("src.knowledge.db.get_db_path", return_value=citations_db): + stats = get_citation_stats(project="eeglab") + assert stats.by_paper == { + DOI_A: {"2019": 2, "2020": 1}, + DOI_B: {"2020": 1, "2021": 1}, + } + + def test_empty_database(self, tmp_path: Path): + db_path = tmp_path / "knowledge" / "empty.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db() + stats = get_citation_stats(project="eeglab") + assert stats.total == 0 + assert stats.per_year == {} + assert stats.by_paper == {} + + +class TestCitesDoiUpsert: + def test_backfill_sets_link_on_existing_row(self, tmp_path: Path): + """A row first stored without a link gets it on a later citation sync.""" + db_path = tmp_path / "knowledge" / "test.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db() + with get_connection() as conn: + _add_paper(conn, "p1", created_at="2020", cites_doi=None) + _add_paper(conn, "p1", created_at="2020", cites_doi=DOI_A) + conn.commit() + row = conn.execute( + "SELECT cites_doi FROM papers WHERE external_id = 'p1'" + ).fetchone() + assert row["cites_doi"] == DOI_A + + def test_first_link_wins_over_later_link(self, tmp_path: Path): + """COALESCE keeps the first recorded canonical DOI for overlapping papers.""" + db_path = tmp_path / "knowledge" / "test.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db() + with get_connection() as conn: + _add_paper(conn, "p1", created_at="2020", cites_doi=DOI_A) + _add_paper(conn, "p1", created_at="2020", cites_doi=DOI_B) + conn.commit() + row = conn.execute( + "SELECT cites_doi FROM papers WHERE external_id = 'p1'" + ).fetchone() + assert row["cites_doi"] == DOI_A + + def test_keyword_sync_does_not_erase_link(self, tmp_path: Path): + """A later keyword sync (cites_doi=None) must not clobber an existing link.""" + db_path = tmp_path / "knowledge" / "test.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + init_db() + with get_connection() as conn: + _add_paper(conn, "p1", created_at="2020", cites_doi=DOI_A) + _add_paper(conn, "p1", created_at="2020", cites_doi=None) + conn.commit() + row = conn.execute( + "SELECT cites_doi FROM papers WHERE external_id = 'p1'" + ).fetchone() + assert row["cites_doi"] == DOI_A + + +class TestCitesDoiMigration: + def test_migration_adds_column_to_legacy_papers_table(self, tmp_path: Path): + """A papers table created before cites_doi gains the column via init_db.""" + db_path = tmp_path / "knowledge" / "legacy.db" + with patch("src.knowledge.db.get_db_path", return_value=db_path): + # Simulate a pre-migration schema: papers without cites_doi. + with get_connection() as conn: + conn.execute( + """ + CREATE TABLE papers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + source TEXT NOT NULL, + external_id TEXT NOT NULL, + title TEXT NOT NULL, + first_message TEXT, + status TEXT NOT NULL DEFAULT 'published', + url TEXT NOT NULL, + created_at TEXT, + synced_at TEXT NOT NULL, + UNIQUE(source, external_id) + ) + """ + ) + conn.commit() + cols_before = [r[1] for r in conn.execute("PRAGMA table_info(papers)")] + assert "cites_doi" not in cols_before + + # Running init_db must migrate the existing table in place. + init_db() + with get_connection() as conn: + cols_after = [r[1] for r in conn.execute("PRAGMA table_info(papers)")] + # The new column is usable for inserts after migration. + _add_paper(conn, "p1", created_at="2020", cites_doi=DOI_A) + conn.commit() + row = conn.execute( + "SELECT cites_doi FROM papers WHERE external_id = 'p1'" + ).fetchone() + + assert "cites_doi" in cols_after + assert row["cites_doi"] == DOI_A diff --git a/tests/test_knowledge/test_papers_sync.py b/tests/test_knowledge/test_papers_sync.py index b23740c..edf45c3 100644 --- a/tests/test_knowledge/test_papers_sync.py +++ b/tests/test_knowledge/test_papers_sync.py @@ -165,6 +165,21 @@ def test_upsert_deduplicates_same_paper(self, temp_db: Path): count = conn.execute("SELECT COUNT(*) AS c FROM papers").fetchone()["c"] assert count == 1 + def test_stores_cites_doi_on_each_row(self, temp_db: Path): + # A citation sync threads the canonical DOI through to each stored row. + papers = [ + Paper(title="Citing A", ids=IDSet(openalex_id="https://openalex.org/W1"), year=2023), + Paper(title="Citing B", ids=IDSet(openalex_id="https://openalex.org/W2"), year=2024), + ] + with patch("src.knowledge.db.get_db_path", return_value=temp_db): + _store_papers(papers, "test", cites_doi="10.1/canonical") + with get_connection("test") as conn: + links = { + r["external_id"]: r["cites_doi"] + for r in conn.execute("SELECT external_id, cites_doi FROM papers") + } + assert links == {"W1": "10.1/canonical", "W2": "10.1/canonical"} + def test_force_source_uses_native_id(self, temp_db: Path): # A PubMed-restricted sync should label the row 'pubmed' using the PMID, # even though the paper also carries an OpenAlex id.