From 9f1043e1cca43a65fc7487b0e97792f246002291 Mon Sep 17 00:00:00 2001 From: Tiankai Ma Date: Thu, 28 May 2026 13:30:37 +0800 Subject: [PATCH 1/9] refactor(auth): unify oauth discovery and sign-in flows --- .env.example | 6 +- docs/features/oauth.json | 1 + docs/features/user.json | 1 + messages/en-us.json | 13 + messages/zh-cn.json | 13 + .../api/auth/route.ts | 9 +- .../api/mcp/route.ts | 7 +- .../oauth-authorization-server/route.ts | 7 +- .../oauth-protected-resource/api/mcp/route.ts | 41 +- .../oauth-protected-resource/route.ts | 7 +- .../openid-configuration/api/auth/route.ts | 9 +- .../openid-configuration/api/mcp/route.ts | 7 +- .../.well-known/openid-configuration/route.ts | 7 +- src/app/actions/auth.ts | 19 + src/app/actions/oauth.ts | 109 ++--- .../oauth/oauth-client-create-dialog.tsx | 9 +- src/app/admin/oauth/oauth-client-list.tsx | 284 +++++++------ .../oauth/oauth-client-manager-shared.ts | 75 ++-- src/app/admin/oauth/oauth-client-manager.tsx | 12 +- src/app/admin/oauth/oauth-client-overview.tsx | 147 ++----- src/app/admin/oauth/page.tsx | 3 +- .../.well-known/openid-configuration/route.ts | 9 +- src/app/api/auth/oauth2/token/route.ts | 27 +- .../oauth-authorization-server/route.ts | 7 +- .../.well-known/openid-configuration/route.ts | 7 +- src/app/oauth/device/actions.ts | 82 ++-- src/app/oauth/device/page.tsx | 325 +++++++-------- src/app/signin/page.tsx | 18 +- src/app/u/profile-view.tsx | 16 +- src/auth.ts | 4 +- src/components/sign-in-link.tsx | 22 + src/components/user-menu.tsx | 26 +- src/env.ts | 382 +++++------------- src/lib/auth/auth-config.ts | 66 +-- src/lib/auth/auth-origins.ts | 70 ++-- src/lib/auth/auth-routing.ts | 29 +- src/lib/auth/better-auth-options.ts | 116 +++--- src/lib/auth/debug-auth.ts | 143 ++++--- src/lib/auth/helpers.ts | 21 +- src/lib/auth/oauth-profile.ts | 102 +++-- src/lib/auth/provider-ids.ts | 11 +- src/lib/mcp/auth.ts | 47 +-- src/lib/oauth/client-registration.ts | 43 +- src/lib/oauth/constants.ts | 44 ++ src/lib/oauth/discovery-metadata.ts | 36 +- src/lib/oauth/discovery-routes.ts | 65 +++ src/lib/oauth/logging.ts | 91 ----- src/lib/oauth/loopback-redirect.ts | 61 +-- src/lib/oauth/provider-api.ts | 53 +-- src/lib/oauth/redirect.ts | 21 - src/lib/oauth/utils.ts | 14 - src/lib/site-url.ts | 56 +-- tests/e2e/src/app/admin/oauth/test.ts | 16 +- tests/e2e/src/app/api/oauth/register/test.ts | 41 +- tests/e2e/src/app/signin/test.ts | 52 ++- tests/e2e/src/app/welcome/test.ts | 2 +- tests/e2e/utils/auth.ts | 1 + tests/e2e/utils/e2e-db/oauth.ts | 37 +- tests/unit/auth-config.test.ts | 44 ++ tests/unit/auth-helpers.test.ts | 14 +- tests/unit/auth-origins.test.ts | 35 +- tests/unit/auth-provider-routing.test.ts | 11 + tests/unit/debug-auth.test.ts | 65 +++ tests/unit/env.test.ts | 63 ++- tests/unit/oauth-client-registration.test.ts | 65 +++ tests/unit/oauth-constants.test.ts | 38 ++ tests/unit/oauth-debug.test.ts | 69 ++++ tests/unit/oauth-discovery-metadata.test.ts | 38 ++ tests/unit/oauth-loopback-redirect.test.ts | 14 +- tests/unit/oauth-profile.test.ts | 83 +++- tests/unit/oauth-utils.test.ts | 22 +- tests/unit/provider-api.test.ts | 15 +- tests/unit/signin-callback-url.test.ts | 10 +- 73 files changed, 1922 insertions(+), 1613 deletions(-) create mode 100644 src/app/actions/auth.ts create mode 100644 src/components/sign-in-link.tsx create mode 100644 src/lib/oauth/constants.ts create mode 100644 src/lib/oauth/discovery-routes.ts delete mode 100644 src/lib/oauth/logging.ts delete mode 100644 src/lib/oauth/redirect.ts create mode 100644 tests/unit/auth-config.test.ts create mode 100644 tests/unit/debug-auth.test.ts create mode 100644 tests/unit/oauth-client-registration.test.ts create mode 100644 tests/unit/oauth-constants.test.ts create mode 100644 tests/unit/oauth-debug.test.ts create mode 100644 tests/unit/oauth-discovery-metadata.test.ts diff --git a/.env.example b/.env.example index c4ae9229..195af580 100644 --- a/.env.example +++ b/.env.example @@ -4,14 +4,10 @@ # Docker Compose app dev (`bun run dev:docker`) # overrides DATABASE_URL and S3 endpoint to use service DNS names. DATABASE_URL="postgresql://postgres:postgres@127.0.0.1:5432/life_ustc_dev" -JWT_SECRET="replace-with-random-secret" WEBHOOK_SECRET="replace-with-random-secret" AUTH_SECRET="replace-with-random-secret" APP_PUBLIC_ORIGIN="http://localhost:3000" APP_CANONICAL_ORIGIN="https://life-ustc.tiankaima.dev" -BETTER_AUTH_URL="http://localhost:3000" -# Optional dedicated key for encrypting OIDC client secrets at rest. -# OIDC_CLIENT_SECRET_ENCRYPTION_KEY="replace-with-random-secret" # Storage # These values also drive the shared MinIO defaults used by `docker-compose.dev.yml` @@ -45,7 +41,7 @@ AUTH_OIDC_CLIENT_SECRET="" # Dev-only defaults # UPLOAD_TOTAL_QUOTA_MB="1024" # DEV_DEBUG_USERNAME="dev-user" -# DEV_DEBUG_NAME="Dev Debug User" +# DEV_DEBUG_NAME="Dev User" # DEV_ADMIN_USERNAME="dev-admin" # DEV_ADMIN_NAME="Dev Admin User" # When E2E_DEBUG_AUTH=1 (e.g. Playwright), set both — no defaults in non-dev NODE_ENV: diff --git a/docs/features/oauth.json b/docs/features/oauth.json index 0f004959..32113bd7 100644 --- a/docs/features/oauth.json +++ b/docs/features/oauth.json @@ -22,6 +22,7 @@ "protected-resource-canonical-path": "For protected resources with a path (currently MCP resource /api/mcp), the canonical entry per RFC 9728 should be /.well-known/oauth-protected-resource/api/mcp; the root-level /.well-known/oauth-protected-resource serves only as a compatibility alias redirect to the canonical address, to avoid returning metadata inconsistent with the resource field.", "mcp-discovery-compatibility-aliases": "Because some MCP clients probe resource-relative or issuer-style well-known paths before settling on canonical metadata, /api/mcp/.well-known/oauth-authorization-server, /.well-known/oauth-authorization-server/api/mcp, /api/mcp/.well-known/openid-configuration, and /.well-known/openid-configuration/api/mcp should redirect to the issuer metadata used by the MCP protected resource.", "aliases-use-redirect": "Compatibility aliases should use redirects rather than redundantly returning a JSON that looks usable but is inconsistent with issuer/resource validation, so that clients ultimately complete metadata validation at the canonical address.", + "discovery-route-targets": "Discovery metadata and compatibility aliases are wired through a shared route-target table so canonical metadata and alias redirects stay consistent when paths are added or retired.", "discovery-cors": "Discovery metadata should support cross-origin reading; at minimum return Access-Control-Allow-Origin: * for OpenID discovery, and keep the CORS behavior of authorization server metadata and protected resource metadata consistent, reducing compatibility risks for browser-type clients and debugging tools.", "transport-cors": "The MCP transport endpoint /api/mcp should also support browser-based clients and debugging tools using Bearer tokens: answer OPTIONS preflights, allow the MCP-specific request headers, and expose MCP-Session-Id and WWW-Authenticate on cross-origin transport responses.", "transport-origin-validation": "Per the MCP Streamable HTTP transport guidance, /api/mcp should reject requests carrying an Origin header unless that origin matches the app's trusted origin set (public/canonical origin, localhost dev aliases, and allowed preview hosts). Non-browser clients that omit Origin remain supported." diff --git a/docs/features/user.json b/docs/features/user.json index acc625ec..8890612a 100644 --- a/docs/features/user.json +++ b/docs/features/user.json @@ -63,6 +63,7 @@ "display": { "fields": [ "callbackUrl query parameter (origin page)", + "Sign-in action links include the current path and query as callbackUrl", "Fallback to home page if no origin" ] } diff --git a/messages/en-us.json b/messages/en-us.json index 3b9f7770..878664e2 100644 --- a/messages/en-us.json +++ b/messages/en-us.json @@ -500,6 +500,9 @@ "searchShortcutHint": "Press ⌘K or Ctrl+K to focus search", "allSitesTab": "All websites", "overviewHint": "Overview shows 5 sites only: your pins first, then recommended ones.", + "viewMode": "Website view", + "gridView": "Grid", + "listView": "List", "pin": "Pin", "unpin": "Unpin", "pinFailedTitle": "Pin update failed", @@ -582,6 +585,7 @@ "startShort": "From", "endShort": "To", "empty": "No routes serve that stop pair in the selected direction. Try reversing the direction or picking another stop.", + "emptyReverseAction": "Reverse direction", "departIn": "Departs in about {count} minutes", "departEtaMinutes": "{count, plural, one {# minute} other {# minutes}}", "departEtaHours": "{count, plural, one {# hour} other {# hours}}", @@ -589,6 +593,8 @@ "etaUnknown": "ETA unavailable", "estimatedHint": "~ marks an estimated time inferred from nearby stops on the same trip.", "clientHint": "Day type, route matching, and ranking are computed in your browser from the raw timetable data.", + "direction": "Direction", + "routes": "Routes", "routeSectionsCount": "{count} route sections", "departureColumn": "Depart", "routeColumn": "Route timetable", @@ -1145,6 +1151,7 @@ "descriptionLabel": "Details", "descriptionPlaceholder": "Add requirements, submission format, and grading notes", "publishedAt": "Published", + "homeworkPublishedAt": "Homework published", "submissionStart": "Submission opens", "submissionDue": "Submission due", "helperPublishNow": "Publish now", @@ -1216,6 +1223,9 @@ "filterIncomplete": "Incomplete", "filterCompleted": "Completed", "filterAll": "All", + "viewMode": "Homework view", + "cardView": "Cards", + "listView": "List", "filterEmptyTitle": "No homework under this filter", "filterEmptyDescription": "Try another filter, or check back later.", "addButton": "Add homework", @@ -1566,6 +1576,9 @@ "previewScopeCount": "{count, plural, =0 {No scopes selected} one {# scope selected} other {# scopes selected}}", "existingClients": "Existing Clients", "existingClientsDescription": "Trusted first-party clients are separated from external and public clients so the admin inventory is easier to audit.", + "clientPageStatus": "Showing {start}-{end} of {total}", + "previousPage": "Previous", + "nextPage": "Next", "tableColumnScopes": "Scopes", "tableColumnRedirects": "Redirects", "tableColumnActions": "Actions", diff --git a/messages/zh-cn.json b/messages/zh-cn.json index d76b9641..19827601 100644 --- a/messages/zh-cn.json +++ b/messages/zh-cn.json @@ -500,6 +500,9 @@ "searchShortcutHint": "按 Ctrl+K 或 ⌘K 聚焦搜索", "allSitesTab": "全部网站", "overviewHint": "总览仅展示 5 个网站:优先展示你的置顶,其余按推荐补齐。", + "viewMode": "网站视图", + "gridView": "网格", + "listView": "列表", "pin": "置顶", "unpin": "取消置顶", "pinFailedTitle": "置顶更新失败", @@ -582,6 +585,7 @@ "startShort": "起", "endShort": "终", "empty": "当前方向下没有可用路线。可以尝试反向或重新选择站点。", + "emptyReverseAction": "反向查询", "departIn": "约 {count} 分钟后发车", "departEtaMinutes": "{count} 分钟", "departEtaHours": "{count} 小时", @@ -589,6 +593,8 @@ "etaUnknown": "到站时间未知", "estimatedHint": "~ 表示该时间由同班次相邻站点推算得出。", "clientHint": "工作日/周末判断、路线匹配和排序都在浏览器端基于原始时刻表完成。", + "direction": "方向", + "routes": "路线", "routeSectionsCount": "共 {count} 条路线分组", "departureColumn": "出发", "routeColumn": "路线时刻", @@ -1122,6 +1128,7 @@ "descriptionLabel": "说明", "descriptionPlaceholder": "补充作业要求、提交方式、评分规则等", "publishedAt": "发布日期", + "homeworkPublishedAt": "作业发布日期", "submissionStart": "提交开始", "submissionDue": "提交截止", "helperPublishNow": "立即发布", @@ -1193,6 +1200,9 @@ "filterIncomplete": "未完成", "filterCompleted": "已完成", "filterAll": "全部", + "viewMode": "作业视图", + "cardView": "卡片", + "listView": "列表", "filterEmptyTitle": "当前筛选下暂无作业", "filterEmptyDescription": "切换筛选条件,或稍后再试。", "addButton": "添加作业", @@ -1543,6 +1553,9 @@ "previewScopeCount": "{count, plural, =0 {尚未选择权限} one {已选择 # 个权限} other {已选择 # 个权限}}", "existingClients": "已有客户端", "existingClientsDescription": "把可信第一方客户端与外部 / 公共客户端分开展示,更方便后台审计。", + "clientPageStatus": "显示第 {start}-{end} 个,共 {total} 个", + "previousPage": "上一页", + "nextPage": "下一页", "tableColumnScopes": "权限范围", "tableColumnRedirects": "重定向", "tableColumnActions": "操作", diff --git a/src/app/.well-known/oauth-authorization-server/api/auth/route.ts b/src/app/.well-known/oauth-authorization-server/api/auth/route.ts index de911349..1396b60f 100644 --- a/src/app/.well-known/oauth-authorization-server/api/auth/route.ts +++ b/src/app/.well-known/oauth-authorization-server/api/auth/route.ts @@ -1,7 +1,4 @@ -import { - createDiscoveryMetadataRoute, - getAuthServerMetadataResponse, -} from "@/lib/oauth/discovery-metadata"; +import { createOAuthDiscoveryRoute } from "@/lib/oauth/discovery-routes"; export const dynamic = "force-dynamic"; @@ -9,6 +6,4 @@ export const dynamic = "force-dynamic"; * Canonical RFC 8414 authorization server metadata for issuer `/api/auth`. * @response 200 */ -export const { GET, OPTIONS } = createDiscoveryMetadataRoute( - getAuthServerMetadataResponse, -); +export const { GET, OPTIONS } = createOAuthDiscoveryRoute("authServerMetadata"); diff --git a/src/app/.well-known/oauth-authorization-server/api/mcp/route.ts b/src/app/.well-known/oauth-authorization-server/api/mcp/route.ts index 324aab31..0734c2ee 100644 --- a/src/app/.well-known/oauth-authorization-server/api/mcp/route.ts +++ b/src/app/.well-known/oauth-authorization-server/api/mcp/route.ts @@ -1,5 +1,4 @@ -import { getOAuthAuthorizationServerMetadataUrl } from "@/lib/mcp/urls"; -import { createDiscoveryRedirectRoute } from "@/lib/oauth/discovery-metadata"; +import { createOAuthDiscoveryRoute } from "@/lib/oauth/discovery-routes"; export const dynamic = "force-dynamic"; @@ -7,6 +6,4 @@ export const dynamic = "force-dynamic"; * Compatibility alias for clients that probe resource-path authorization-server metadata. * @response 307 */ -export const { GET, OPTIONS } = createDiscoveryRedirectRoute((request) => - getOAuthAuthorizationServerMetadataUrl(request), -); +export const { GET, OPTIONS } = createOAuthDiscoveryRoute("authServerAlias"); diff --git a/src/app/.well-known/oauth-authorization-server/route.ts b/src/app/.well-known/oauth-authorization-server/route.ts index 3bdabc39..d45a543c 100644 --- a/src/app/.well-known/oauth-authorization-server/route.ts +++ b/src/app/.well-known/oauth-authorization-server/route.ts @@ -1,5 +1,4 @@ -import { getOAuthAuthorizationServerMetadataUrl } from "@/lib/mcp/urls"; -import { createDiscoveryRedirectRoute } from "@/lib/oauth/discovery-metadata"; +import { createOAuthDiscoveryRoute } from "@/lib/oauth/discovery-routes"; export const dynamic = "force-dynamic"; /** @@ -7,6 +6,4 @@ export const dynamic = "force-dynamic"; * The canonical metadata URL for issuer `/api/auth` is path-specific. * @response 307 */ -export const { GET, OPTIONS } = createDiscoveryRedirectRoute(() => - getOAuthAuthorizationServerMetadataUrl(), -); +export const { GET, OPTIONS } = createOAuthDiscoveryRoute("authServerAlias"); diff --git a/src/app/.well-known/oauth-protected-resource/api/mcp/route.ts b/src/app/.well-known/oauth-protected-resource/api/mcp/route.ts index 50368b0e..c93a47fb 100644 --- a/src/app/.well-known/oauth-protected-resource/api/mcp/route.ts +++ b/src/app/.well-known/oauth-protected-resource/api/mcp/route.ts @@ -1,10 +1,4 @@ -import { NextResponse } from "next/server"; -import { getMcpServerUrl, getOAuthIssuerUrl } from "@/lib/mcp/urls"; -import { - createDiscoveryMetadataRoute, - getDiscoveryOptionsResponse, -} from "@/lib/oauth/discovery-metadata"; -import { MCP_TOOLS_SCOPE } from "@/lib/oauth/utils"; +import { createOAuthDiscoveryRoute } from "@/lib/oauth/discovery-routes"; export const dynamic = "force-dynamic"; @@ -12,35 +6,6 @@ export const dynamic = "force-dynamic"; * Canonical RFC 9728 protected resource metadata for MCP. * @response 200 */ -async function getProtectedResourceMetadataResponse(request: Request) { - const issuerUrl = getOAuthIssuerUrl(request); - - return NextResponse.json( - { - resource: getMcpServerUrl(request).toString(), - authorization_servers: [issuerUrl.toString()], - scopes_supported: [MCP_TOOLS_SCOPE], - bearer_methods_supported: ["header"], - resource_documentation: new URL("/api-docs", issuerUrl).toString(), - }, - { - headers: { - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "GET, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type, Authorization", - }, - }, - ); -} - -export const { GET } = createDiscoveryMetadataRoute( - getProtectedResourceMetadataResponse, +export const { GET, OPTIONS } = createOAuthDiscoveryRoute( + "protectedResourceMetadata", ); - -/** - * CORS preflight for protected resource metadata. - * @response 204 - */ -export function OPTIONS() { - return getDiscoveryOptionsResponse(); -} diff --git a/src/app/.well-known/oauth-protected-resource/route.ts b/src/app/.well-known/oauth-protected-resource/route.ts index 4a1ab97b..fc5bf94b 100644 --- a/src/app/.well-known/oauth-protected-resource/route.ts +++ b/src/app/.well-known/oauth-protected-resource/route.ts @@ -1,5 +1,4 @@ -import { getOAuthProtectedResourceMetadataUrl } from "@/lib/mcp/urls"; -import { createDiscoveryRedirectRoute } from "@/lib/oauth/discovery-metadata"; +import { createOAuthDiscoveryRoute } from "@/lib/oauth/discovery-routes"; export const dynamic = "force-dynamic"; @@ -8,6 +7,6 @@ export const dynamic = "force-dynamic"; * The canonical RFC 9728 URL is path-specific for resource `/api/mcp`. * @response 307 */ -export const { GET, OPTIONS } = createDiscoveryRedirectRoute(() => - getOAuthProtectedResourceMetadataUrl(), +export const { GET, OPTIONS } = createOAuthDiscoveryRoute( + "protectedResourceAlias", ); diff --git a/src/app/.well-known/openid-configuration/api/auth/route.ts b/src/app/.well-known/openid-configuration/api/auth/route.ts index 6dad123c..acee7e56 100644 --- a/src/app/.well-known/openid-configuration/api/auth/route.ts +++ b/src/app/.well-known/openid-configuration/api/auth/route.ts @@ -1,7 +1,4 @@ -import { - createDiscoveryMetadataRoute, - getOpenIdMetadataResponse, -} from "@/lib/oauth/discovery-metadata"; +import { createOAuthDiscoveryRoute } from "@/lib/oauth/discovery-routes"; export const dynamic = "force-dynamic"; @@ -9,6 +6,4 @@ export const dynamic = "force-dynamic"; * RFC 8414-compatible path form for OpenID provider metadata. * @response 200 */ -export const { GET, OPTIONS } = createDiscoveryMetadataRoute( - getOpenIdMetadataResponse, -); +export const { GET, OPTIONS } = createOAuthDiscoveryRoute("openIdMetadata"); diff --git a/src/app/.well-known/openid-configuration/api/mcp/route.ts b/src/app/.well-known/openid-configuration/api/mcp/route.ts index 98fc89d4..2a72fbfa 100644 --- a/src/app/.well-known/openid-configuration/api/mcp/route.ts +++ b/src/app/.well-known/openid-configuration/api/mcp/route.ts @@ -1,5 +1,4 @@ -import { getOAuthOpenIdConfigurationUrl } from "@/lib/mcp/urls"; -import { createDiscoveryRedirectRoute } from "@/lib/oauth/discovery-metadata"; +import { createOAuthDiscoveryRoute } from "@/lib/oauth/discovery-routes"; export const dynamic = "force-dynamic"; @@ -7,6 +6,4 @@ export const dynamic = "force-dynamic"; * Compatibility alias for clients that probe resource-path OIDC metadata. * @response 307 */ -export const { GET, OPTIONS } = createDiscoveryRedirectRoute((request) => - getOAuthOpenIdConfigurationUrl(request), -); +export const { GET, OPTIONS } = createOAuthDiscoveryRoute("openIdAlias"); diff --git a/src/app/.well-known/openid-configuration/route.ts b/src/app/.well-known/openid-configuration/route.ts index 84e55a25..f610e018 100644 --- a/src/app/.well-known/openid-configuration/route.ts +++ b/src/app/.well-known/openid-configuration/route.ts @@ -1,5 +1,4 @@ -import { getOAuthOpenIdConfigurationUrl } from "@/lib/mcp/urls"; -import { createDiscoveryRedirectRoute } from "@/lib/oauth/discovery-metadata"; +import { createOAuthDiscoveryRoute } from "@/lib/oauth/discovery-routes"; export const dynamic = "force-dynamic"; /** @@ -7,6 +6,4 @@ export const dynamic = "force-dynamic"; * The canonical OIDC discovery URL remains `{issuer}/.well-known/openid-configuration`. * @response 307 */ -export const { GET, OPTIONS } = createDiscoveryRedirectRoute(() => - getOAuthOpenIdConfigurationUrl(), -); +export const { GET, OPTIONS } = createOAuthDiscoveryRoute("openIdAlias"); diff --git a/src/app/actions/auth.ts b/src/app/actions/auth.ts new file mode 100644 index 00000000..0dbb3b4c --- /dev/null +++ b/src/app/actions/auth.ts @@ -0,0 +1,19 @@ +"use server"; + +import { revalidatePath } from "next/cache"; +import { signOut } from "@/auth"; +import { logServerActionError } from "@/lib/log/app-logger"; + +export async function signOutCurrentUser() { + try { + await signOut({ redirect: false }); + revalidatePath("/"); + + return { success: true }; + } catch (error) { + logServerActionError("Failed to sign out", error, { + action: "signOutCurrentUser", + }); + return { error: "Failed to sign out" }; + } +} diff --git a/src/app/actions/oauth.ts b/src/app/actions/oauth.ts index deb98327..0442ad1f 100644 --- a/src/app/actions/oauth.ts +++ b/src/app/actions/oauth.ts @@ -7,71 +7,83 @@ import { auth, authApi } from "@/auth"; import { Prisma } from "@/generated/prisma/client"; import { prisma } from "@/lib/db/prisma"; import { logServerActionError } from "@/lib/log/app-logger"; -import { resolveOAuthClientScopes } from "@/lib/oauth/client-registration"; -import { asOAuthProviderApi } from "@/lib/oauth/provider-api"; import { - DEFAULT_OAUTH_CLIENT_SCOPES, + resolveOAuthClientGrantTypes, + resolveOAuthClientScopes, +} from "@/lib/oauth/client-registration"; +import { + isSupportedOAuthClientAuthMethod, OAUTH_CLIENT_SECRET_BASIC_AUTH_METHOD, OAUTH_CLIENT_SECRET_POST_AUTH_METHOD, + OAUTH_CODE_RESPONSE_TYPE, OAUTH_PUBLIC_CLIENT_AUTH_METHOD, -} from "@/lib/oauth/utils"; + type SupportedOAuthClientAuthMethod, +} from "@/lib/oauth/constants"; +import { asOAuthProviderApi } from "@/lib/oauth/provider-api"; type CreateOAuthClientResult = | { error: string } | { success: true; clientId: string; clientSecret: string | null }; -function resolveAdminOAuthClientPattern(tokenEndpointAuthMethod: string) { - if (tokenEndpointAuthMethod === OAUTH_PUBLIC_CLIENT_AUTH_METHOD) { - return { - pattern: "public_pkce", - skipConsent: false, - enableEndSession: false, - } as const; - } - - if (tokenEndpointAuthMethod === OAUTH_CLIENT_SECRET_POST_AUTH_METHOD) { - return { - pattern: "confidential_connector", - skipConsent: false, - enableEndSession: false, - } as const; +const ADMIN_OAUTH_CLIENT_PATTERNS: Record< + SupportedOAuthClientAuthMethod, + { + pattern: "public_pkce" | "confidential_connector" | "trusted_first_party"; + skipConsent: boolean; + enableEndSession: boolean; } - - return { +> = { + [OAUTH_PUBLIC_CLIENT_AUTH_METHOD]: { + pattern: "public_pkce", + skipConsent: false, + enableEndSession: false, + }, + [OAUTH_CLIENT_SECRET_POST_AUTH_METHOD]: { + pattern: "confidential_connector", + skipConsent: false, + enableEndSession: false, + }, + [OAUTH_CLIENT_SECRET_BASIC_AUTH_METHOD]: { pattern: "trusted_first_party", skipConsent: true, enableEndSession: true, - } as const; + }, +} as const; + +function resolveAdminOAuthClientPattern(tokenEndpointAuthMethod: string) { + return isSupportedOAuthClientAuthMethod(tokenEndpointAuthMethod) + ? ADMIN_OAUTH_CLIENT_PATTERNS[tokenEndpointAuthMethod] + : null; +} + +function nonEmptyString(value: unknown): string | null { + return typeof value === "string" && value.trim().length > 0 ? value : null; } function getOAuthActionErrorMessage(error: unknown, fallback: string) { - if (error instanceof Error && error.message.trim().length > 0) { - return error.message; + const errorMessage = + error instanceof Error ? nonEmptyString(error.message) : null; + if (errorMessage) { + return errorMessage; } if (error && typeof error === "object") { const record = error as Record; - if ( - typeof record.message === "string" && - record.message.trim().length > 0 - ) { - return record.message; + const recordMessage = nonEmptyString(record.message); + if (recordMessage) { + return recordMessage; } const body = record.body; if (body && typeof body === "object") { const bodyRecord = body as Record; - if ( - typeof bodyRecord.error_description === "string" && - bodyRecord.error_description.trim().length > 0 - ) { - return bodyRecord.error_description; + const errorDescription = nonEmptyString(bodyRecord.error_description); + if (errorDescription) { + return errorDescription; } - if ( - typeof bodyRecord.message === "string" && - bodyRecord.message.trim().length > 0 - ) { - return bodyRecord.message; + const bodyMessage = nonEmptyString(bodyRecord.message); + if (bodyMessage) { + return bodyMessage; } } } @@ -120,21 +132,16 @@ export async function createOAuthClient( .map((s) => s.trim()) .filter(Boolean); - const scopesResult = resolveOAuthClientScopes({ - defaultScopes: [...DEFAULT_OAUTH_CLIENT_SCOPES], - requestedScopes: requestedScopes.length > 0 ? requestedScopes : undefined, - }); + const scopesResult = resolveOAuthClientScopes( + requestedScopes.length > 0 ? requestedScopes : undefined, + ); if ("error" in scopesResult) { return { error: scopesResult.error }; } const scopes = scopesResult.scopes; const clientPattern = resolveAdminOAuthClientPattern(tokenEndpointAuthMethod); - if ( - tokenEndpointAuthMethod !== OAUTH_CLIENT_SECRET_BASIC_AUTH_METHOD && - tokenEndpointAuthMethod !== OAUTH_CLIENT_SECRET_POST_AUTH_METHOD && - tokenEndpointAuthMethod !== OAUTH_PUBLIC_CLIENT_AUTH_METHOD - ) { + if (!clientPattern) { return { error: "Unsupported token endpoint auth method" }; } @@ -145,10 +152,8 @@ export async function createOAuthClient( client_name: name, redirect_uris: redirectUris, token_endpoint_auth_method: tokenEndpointAuthMethod, - grant_types: scopes.includes("offline_access") - ? ["authorization_code", "refresh_token"] - : ["authorization_code"], - response_types: ["code"], + grant_types: resolveOAuthClientGrantTypes(scopes), + response_types: [OAUTH_CODE_RESPONSE_TYPE], scope: scopes.join(" "), require_pkce: true, skip_consent: clientPattern.skipConsent, diff --git a/src/app/admin/oauth/oauth-client-create-dialog.tsx b/src/app/admin/oauth/oauth-client-create-dialog.tsx index bf945e3d..3c28a17b 100644 --- a/src/app/admin/oauth/oauth-client-create-dialog.tsx +++ b/src/app/admin/oauth/oauth-client-create-dialog.tsx @@ -20,7 +20,6 @@ import { cn } from "@/lib/utils"; import { AUTH_METHOD_OPTIONS, getAuthMethodOption, - getClientTypeBadgeVariant, getScopeInputId, type OAuthTranslator, parseRedirectUris, @@ -88,7 +87,7 @@ export function OAuthClientCreateDialog({ htmlFor={inputId} key={option.value} className={cn( - "cursor-pointer rounded-2xl border p-4 transition-colors", + "cursor-pointer rounded-lg border p-4 transition-colors", checked ? option.accentClassName : "border-border bg-card/72 hover:bg-accent/40", @@ -106,7 +105,7 @@ export function OAuthClientCreateDialog({
{t(option.strategyTitleKey)} - + {t(option.labelKey)}
diff --git a/src/app/admin/oauth/oauth-client-list.tsx b/src/app/admin/oauth/oauth-client-list.tsx index ccaf2337..3d8f8ee3 100644 --- a/src/app/admin/oauth/oauth-client-list.tsx +++ b/src/app/admin/oauth/oauth-client-list.tsx @@ -1,6 +1,7 @@ "use client"; -import { Copy } from "lucide-react"; +import { ChevronLeft, ChevronRight, Copy } from "lucide-react"; +import { useState } from "react"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { @@ -13,13 +14,14 @@ import { import { Field, FieldLabel } from "@/components/ui/field"; import { type CreatedCredentials, - getClientPatternDescriptionKey, getClientTypeBadgeVariant, getClientTypeLabel, type OAuthClientInfo, type OAuthTranslator, } from "./oauth-client-manager-shared"; +const CLIENTS_PER_SECTION_PAGE = 3; + function CopyField({ label, value, @@ -36,20 +38,23 @@ function CopyField({ return ( {label} -
+
+

+ {dateTimeFormatter.format(new Date(client.createdAt))} +

-
-
-

+

+
+ {t("clientIdLabel")} -

-
- - {client.clientId} - - -
+
+ + {client.clientId} + +
-
-

+

+

{t("tableColumnRedirects")}

{redirectUris.length === 0 ? ( -

+

) : ( -
+
{redirectUris.map((redirectUri) => (
- {redirectUri} - +
-
-
-

- {t("tableColumnScopes")} +

+ {client.scopes + .map((scope) => t(`scope_${scope}`, { fallback: scope })) + .join(" · ")}

-
- {client.scopes.map((scope) => ( - - - {t(`scope_${scope}`, { fallback: scope })} - - - ))} -
- - + + +
+ ); } @@ -225,36 +211,94 @@ function OAuthClientSection({ description: string; t: OAuthTranslator; }) { + const [page, setPage] = useState(1); + const totalPages = Math.max( + 1, + Math.ceil(clients.length / CLIENTS_PER_SECTION_PAGE), + ); + const currentPage = Math.min(page, totalPages); + const startIndex = (currentPage - 1) * CLIENTS_PER_SECTION_PAGE; + const visibleClients = clients.slice( + startIndex, + startIndex + CLIENTS_PER_SECTION_PAGE, + ); + const showingStart = clients.length === 0 ? 0 : startIndex + 1; + const showingEnd = Math.min( + startIndex + visibleClients.length, + clients.length, + ); + return ( - - -
- {title} +
+
+
+

{title}

{clients.length}
- {description} - - +

+ {description} +

+
+
{clients.length === 0 ? ( -
+
{t(emptyKey)}
) : ( - clients.map((client) => ( - - )) + <> +
+ {visibleClients.map((client) => ( + + ))} +
+ {totalPages > 1 ? ( +
+

+ {t("clientPageStatus", { + start: showingStart, + end: showingEnd, + total: clients.length, + })} +

+
+ + +
+
+ ) : null} + )} - - +
+
); } @@ -270,7 +314,7 @@ export function CreatedCredentialsCard({ t: OAuthTranslator; }) { return ( - + {t("credentialsTitle")} {t("credentialsWarning")} @@ -374,22 +418,22 @@ export function OAuthClientList({ }) { if (clients.length === 0) { return ( - - - {t("existingClients")} - {t("existingClientsDescription")} - - -
- {t("noClients")} -
-
-
+
+

+ {t("existingClients")} +

+

+ {t("existingClientsDescription")} +

+
+ {t("noClients")} +
+
); } return ( -
+
string; export type AuthMethodOption = { - value: string; + value: SupportedOAuthClientAuthMethod; icon: LucideIcon; labelKey: string; descriptionKey: string; @@ -43,6 +52,7 @@ export type AuthMethodOption = { strategyHintKey: string; accentClassName: string; accentIconClassName: string; + badgeVariant: "info" | "success" | "warning"; }; export const AUTH_METHOD_OPTIONS: AuthMethodOption[] = [ @@ -54,10 +64,9 @@ export const AUTH_METHOD_OPTIONS: AuthMethodOption[] = [ strategyTitleKey: "strategyFirstPartyTitle", strategyDescriptionKey: "strategyFirstPartyDescription", strategyHintKey: "strategyFirstPartyHint", - accentClassName: - "border-sky-500/24 bg-sky-500/[0.08] text-sky-800 dark:text-sky-200", - accentIconClassName: - "border-sky-500/24 bg-sky-500/[0.12] text-sky-700 dark:text-sky-200", + accentClassName: "border-foreground bg-muted/45 text-foreground", + accentIconClassName: "border-border bg-background text-foreground", + badgeVariant: "info", }, { value: OAUTH_PUBLIC_CLIENT_AUTH_METHOD, @@ -67,10 +76,9 @@ export const AUTH_METHOD_OPTIONS: AuthMethodOption[] = [ strategyTitleKey: "strategyPublicTitle", strategyDescriptionKey: "strategyPublicDescription", strategyHintKey: "strategyPublicHint", - accentClassName: - "border-emerald-500/24 bg-emerald-500/[0.08] text-emerald-800 dark:text-emerald-200", - accentIconClassName: - "border-emerald-500/24 bg-emerald-500/[0.12] text-emerald-700 dark:text-emerald-200", + accentClassName: "border-foreground bg-muted/45 text-foreground", + accentIconClassName: "border-border bg-background text-foreground", + badgeVariant: "success", }, { value: OAUTH_CLIENT_SECRET_POST_AUTH_METHOD, @@ -80,20 +88,19 @@ export const AUTH_METHOD_OPTIONS: AuthMethodOption[] = [ strategyTitleKey: "strategyAdvancedTitle", strategyDescriptionKey: "strategyAdvancedDescription", strategyHintKey: "strategyAdvancedHint", - accentClassName: - "border-amber-500/24 bg-amber-500/[0.08] text-amber-800 dark:text-amber-100", - accentIconClassName: - "border-amber-500/24 bg-amber-500/[0.12] text-amber-700 dark:text-amber-100", + accentClassName: "border-foreground bg-muted/45 text-foreground", + accentIconClassName: "border-border bg-background text-foreground", + badgeVariant: "warning", }, ]; export const SCOPE_OPTIONS = [ { - value: "openid", + value: OAUTH_OPENID_SCOPE, descriptionKey: "scopeOpenIdDescription", }, { - value: "profile", + value: OAUTH_PROFILE_SCOPE, descriptionKey: "scopeProfileDescription", }, { @@ -103,23 +110,11 @@ export const SCOPE_OPTIONS = [ ] as const; export function getClientTypeBadgeVariant(method: string) { - if (method === OAUTH_PUBLIC_CLIENT_AUTH_METHOD) { - return "success" as const; - } - if (method === OAUTH_CLIENT_SECRET_POST_AUTH_METHOD) { - return "warning" as const; - } - return "info" as const; + return getAuthMethodOption(method).badgeVariant; } export function getClientTypeLabel(t: OAuthTranslator, method: string) { - if (method === OAUTH_PUBLIC_CLIENT_AUTH_METHOD) { - return t("clientTypePublic"); - } - if (method === OAUTH_CLIENT_SECRET_POST_AUTH_METHOD) { - return t("clientTypeConfidentialPost"); - } - return t("clientTypeConfidentialBasic"); + return t(getAuthMethodOption(method).labelKey); } export function getScopeInputId(scope: string) { @@ -140,14 +135,22 @@ export function getAuthMethodOption(value: string) { ); } +export function isTrustedClientAuthMethod(method: string) { + return method === OAUTH_CLIENT_SECRET_BASIC_AUTH_METHOD; +} + +export function isPublicClientAuthMethod(method: string) { + return method === OAUTH_PUBLIC_CLIENT_AUTH_METHOD; +} + export function getClientPatternDescriptionKey(client: OAuthClientInfo) { if (client.isTrusted) { return "clientKindTrustedDescription"; } - if (client.tokenEndpointAuthMethod === OAUTH_PUBLIC_CLIENT_AUTH_METHOD) { + if (isPublicClientAuthMethod(client.tokenEndpointAuthMethod)) { return "clientKindPublicDescription"; } return "clientKindExternalDescription"; } -export const authMethodLeadIcon = KeyRound; +export const AuthMethodLeadIcon = KeyRound; diff --git a/src/app/admin/oauth/oauth-client-manager.tsx b/src/app/admin/oauth/oauth-client-manager.tsx index b7421df1..aad2c6a3 100644 --- a/src/app/admin/oauth/oauth-client-manager.tsx +++ b/src/app/admin/oauth/oauth-client-manager.tsx @@ -11,8 +11,8 @@ import { type CreatedCredentials, DEFAULT_AUTH_METHOD, DEFAULT_SCOPE_VALUES, - OAUTH_CLIENT_SECRET_BASIC_AUTH_METHOD, - OAUTH_PUBLIC_CLIENT_AUTH_METHOD, + isPublicClientAuthMethod, + isTrustedClientAuthMethod, type OAuthClientInfo, parseRedirectUris, } from "./oauth-client-manager-shared"; @@ -43,9 +43,8 @@ export function OAuthClientManager({ ); const publicClients = useMemo( () => - clients.filter( - (client) => - client.tokenEndpointAuthMethod === OAUTH_PUBLIC_CLIENT_AUTH_METHOD, + clients.filter((client) => + isPublicClientAuthMethod(client.tokenEndpointAuthMethod), ), [clients], ); @@ -96,8 +95,7 @@ export function OAuthClientManager({ .getAll("scopes") .map((value) => String(value).trim()) .filter(Boolean); - const isTrusted = - tokenEndpointAuthMethod === OAUTH_CLIENT_SECRET_BASIC_AUTH_METHOD; + const isTrusted = isTrustedClientAuthMethod(tokenEndpointAuthMethod); setLoading(true); const result = await createOAuthClient(formData); diff --git a/src/app/admin/oauth/oauth-client-overview.tsx b/src/app/admin/oauth/oauth-client-overview.tsx index 8d8e7ddb..553df1d2 100644 --- a/src/app/admin/oauth/oauth-client-overview.tsx +++ b/src/app/admin/oauth/oauth-client-overview.tsx @@ -1,21 +1,9 @@ "use client"; import { ShieldCheck } from "lucide-react"; -import { PageStatCard, PageStatGrid } from "@/components/page-layout"; -import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { - Card, - CardDescription, - CardHeader, - CardPanel, - CardTitle, -} from "@/components/ui/card"; -import { cn } from "@/lib/utils"; -import { - AUTH_METHOD_OPTIONS, - authMethodLeadIcon, - getClientTypeBadgeVariant, + AuthMethodLeadIcon, type OAuthTranslator, } from "./oauth-client-manager-shared"; @@ -32,114 +20,39 @@ export function OAuthClientOverview({ onOpenCreateDialog: (method?: string) => void; t: OAuthTranslator; }) { - const LeadIcon = authMethodLeadIcon; - return ( - <> - - -
-
- - Better Auth OAuth Provider -
-
-

- {t("panelGuideTitle")} -

-

- {t("panelGuideDescription")} -

-
-
- {t("strategyFirstPartyTitle")} - {t("strategyPublicTitle")} - {t("strategyAdvancedTitle")} -
+
+
+
+
+ + + Better Auth OAuth Provider +
- -
-

{t("createClient")}

-

- {t("createClientHint")} -

- -

- {t("createClientFootnote")} +

+

+ {t("panelGuideTitle")} +

+

+ {t("panelGuideDescription")}

- - - - - - - - - - - - {t("strategyTitle")} - {t("strategyDescription")} - - - {AUTH_METHOD_OPTIONS.map((option) => { - const Icon = option.icon; +

+ {t("overviewClients")}: {clientCount} · {t("overviewTrusted")}:{" "} + {trustedCount} · {t("overviewPublic")}: {publicCount} +

+
- return ( - - -
- -
-
-
-

- {t(option.strategyTitleKey)} -

- - {t(option.labelKey)} - -
-

- {t(option.strategyDescriptionKey)} -

-

- {t(option.strategyHintKey)} -

-
- -
-
- ); - })} - - - + +
+
); } diff --git a/src/app/admin/oauth/page.tsx b/src/app/admin/oauth/page.tsx index 46d2721b..b89f249a 100644 --- a/src/app/admin/oauth/page.tsx +++ b/src/app/admin/oauth/page.tsx @@ -4,6 +4,7 @@ import { getTranslations } from "next-intl/server"; import { PageBreadcrumbs, PageLayout } from "@/components/page-layout"; import { requireAdminPage } from "@/lib/admin-utils"; import { prisma } from "@/lib/db/prisma"; +import { OAUTH_CLIENT_SECRET_BASIC_AUTH_METHOD } from "@/lib/oauth/constants"; import { toShanghaiIsoString } from "@/lib/time/serialize-date-output"; import { OAuthClientManager } from "./oauth-client-manager"; @@ -62,7 +63,7 @@ export default async function AdminOAuthPage() { clientId: c.clientId, name: c.name ?? "", tokenEndpointAuthMethod: - c.tokenEndpointAuthMethod ?? "client_secret_basic", + c.tokenEndpointAuthMethod ?? OAUTH_CLIENT_SECRET_BASIC_AUTH_METHOD, redirectUris: c.redirectUris, scopes: c.scopes, isTrusted: Boolean(c.skipConsent), diff --git a/src/app/api/auth/.well-known/openid-configuration/route.ts b/src/app/api/auth/.well-known/openid-configuration/route.ts index 76f74dcf..e42952fb 100644 --- a/src/app/api/auth/.well-known/openid-configuration/route.ts +++ b/src/app/api/auth/.well-known/openid-configuration/route.ts @@ -1,7 +1,4 @@ -import { - createDiscoveryMetadataRoute, - getOpenIdMetadataResponse, -} from "@/lib/oauth/discovery-metadata"; +import { createOAuthDiscoveryRoute } from "@/lib/oauth/discovery-routes"; export const dynamic = "force-dynamic"; @@ -9,6 +6,4 @@ export const dynamic = "force-dynamic"; * Canonical OpenID Connect Discovery metadata for issuer `/api/auth`. * @response 200 */ -export const { GET, OPTIONS } = createDiscoveryMetadataRoute( - getOpenIdMetadataResponse, -); +export const { GET, OPTIONS } = createOAuthDiscoveryRoute("openIdMetadata"); diff --git a/src/app/api/auth/oauth2/token/route.ts b/src/app/api/auth/oauth2/token/route.ts index 390c79b0..ceab75a1 100644 --- a/src/app/api/auth/oauth2/token/route.ts +++ b/src/app/api/auth/oauth2/token/route.ts @@ -8,6 +8,7 @@ import { summarizeOAuthRedirectUri, withBetterAuthOAuthDebug, } from "@/lib/log/oauth-debug"; +import { OAUTH_DEVICE_CODE_GRANT_TYPE } from "@/lib/oauth/constants"; import { DEVICE_CODE_ERRORS, DEVICE_CODE_POLL_INTERVAL, @@ -18,7 +19,8 @@ import { hashOAuthClientSecretForDbStorage } from "@/lib/oauth/utils"; export const dynamic = "force-dynamic"; -const DEVICE_CODE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:device_code"; +const DEVICE_ACCESS_TOKEN_EXPIRES_IN = 3600; +const DEVICE_REFRESH_TOKEN_EXPIRES_IN = 30 * 24 * 3600; function deviceCodeError(error: string, status = 400) { return jsonResponse({ error }, { status }); @@ -96,7 +98,6 @@ async function handleDeviceCodeGrant( return deviceCodeError("invalid_request"); } - // Find the device code record const record = await prisma.deviceCode.findUnique({ where: { deviceCode }, select: { @@ -118,16 +119,13 @@ async function handleDeviceCodeGrant( return deviceCodeError("invalid_client"); } - // Check expiry if (record.expiresAt < new Date()) { return deviceCodeError(DEVICE_CODE_ERRORS.EXPIRED_TOKEN); } - // Check polling rate (slow_down) if (record.lastPolledAt) { const elapsed = Date.now() - record.lastPolledAt.getTime(); if (elapsed < DEVICE_CODE_POLL_INTERVAL * 1000) { - // Update lastPolledAt even on slow_down await prisma.deviceCode.update({ where: { id: record.id }, data: { lastPolledAt: new Date() }, @@ -136,13 +134,11 @@ async function handleDeviceCodeGrant( } } - // Update lastPolledAt await prisma.deviceCode.update({ where: { id: record.id }, data: { lastPolledAt: new Date() }, }); - // Check status if (record.status === DEVICE_CODE_STATUS.DENIED) { return deviceCodeError(DEVICE_CODE_ERRORS.ACCESS_DENIED); } @@ -151,20 +147,22 @@ async function handleDeviceCodeGrant( return deviceCodeError(DEVICE_CODE_ERRORS.AUTHORIZATION_PENDING); } - // Status is APPROVED - issue tokens if (!record.userId) { return deviceCodeError("server_error", 500); } const userId = record.userId; - // Generate opaque access token and refresh token const accessTokenPlain = randomBytes(32).toString("base64url"); const refreshTokenPlain = randomBytes(32).toString("base64url"); const accessTokenHash = hashOAuthClientSecretForDbStorage(accessTokenPlain); const refreshTokenHash = hashOAuthClientSecretForDbStorage(refreshTokenPlain); - const accessExpiresAt = new Date(Date.now() + 3600 * 1000); // 1 hour - const refreshExpiresAt = new Date(Date.now() + 30 * 24 * 3600 * 1000); // 30 days + const accessExpiresAt = new Date( + Date.now() + DEVICE_ACCESS_TOKEN_EXPIRES_IN * 1000, + ); + const refreshExpiresAt = new Date( + Date.now() + DEVICE_REFRESH_TOKEN_EXPIRES_IN * 1000, + ); const issued = await prisma.$transaction(async (tx) => { const claimed = await tx.deviceCode.deleteMany({ @@ -213,14 +211,13 @@ async function handleDeviceCodeGrant( return jsonResponse({ access_token: accessTokenPlain, token_type: "Bearer", - expires_in: 3600, + expires_in: DEVICE_ACCESS_TOKEN_EXPIRES_IN, refresh_token: refreshTokenPlain, scope: record.scopes.join(" "), }); } export async function POST(request: Request) { - // Clone request to read body without consuming it const cloned = request.clone(); let params: URLSearchParams; @@ -232,13 +229,12 @@ export async function POST(request: Request) { return withBetterAuthOAuthDebug("POST", request, handlers.POST); } - if (params.get("grant_type") === DEVICE_CODE_GRANT_TYPE) { + if (params.get("grant_type") === OAUTH_DEVICE_CODE_GRANT_TYPE) { return handleDeviceCodeGrant(request, params); } logObservedTokenRedirectRequest(request, params); - // Delegate all other grant types to Better Auth return withBetterAuthOAuthDebug( "POST", await maybeNormalizeTokenLoopbackRedirectRequest(request, params), @@ -246,7 +242,6 @@ export async function POST(request: Request) { ); } -// GET is not used for token endpoint but delegate just in case export function GET(request: Request) { return withBetterAuthOAuthDebug("GET", request, handlers.GET); } diff --git a/src/app/api/mcp/.well-known/oauth-authorization-server/route.ts b/src/app/api/mcp/.well-known/oauth-authorization-server/route.ts index a8847a89..76218997 100644 --- a/src/app/api/mcp/.well-known/oauth-authorization-server/route.ts +++ b/src/app/api/mcp/.well-known/oauth-authorization-server/route.ts @@ -1,5 +1,4 @@ -import { getOAuthAuthorizationServerMetadataUrl } from "@/lib/mcp/urls"; -import { createDiscoveryRedirectRoute } from "@/lib/oauth/discovery-metadata"; +import { createOAuthDiscoveryRoute } from "@/lib/oauth/discovery-routes"; export const dynamic = "force-dynamic"; @@ -7,6 +6,4 @@ export const dynamic = "force-dynamic"; * Compatibility alias for clients that probe authorization-server metadata relative to /api/mcp. * @response 307 */ -export const { GET, OPTIONS } = createDiscoveryRedirectRoute((request) => - getOAuthAuthorizationServerMetadataUrl(request), -); +export const { GET, OPTIONS } = createOAuthDiscoveryRoute("authServerAlias"); diff --git a/src/app/api/mcp/.well-known/openid-configuration/route.ts b/src/app/api/mcp/.well-known/openid-configuration/route.ts index 6696a804..b8c90f7f 100644 --- a/src/app/api/mcp/.well-known/openid-configuration/route.ts +++ b/src/app/api/mcp/.well-known/openid-configuration/route.ts @@ -1,5 +1,4 @@ -import { getOAuthOpenIdConfigurationUrl } from "@/lib/mcp/urls"; -import { createDiscoveryRedirectRoute } from "@/lib/oauth/discovery-metadata"; +import { createOAuthDiscoveryRoute } from "@/lib/oauth/discovery-routes"; export const dynamic = "force-dynamic"; @@ -7,6 +6,4 @@ export const dynamic = "force-dynamic"; * Compatibility alias for clients that probe OIDC metadata relative to /api/mcp. * @response 307 */ -export const { GET, OPTIONS } = createDiscoveryRedirectRoute((request) => - getOAuthOpenIdConfigurationUrl(request), -); +export const { GET, OPTIONS } = createOAuthDiscoveryRoute("openIdAlias"); diff --git a/src/app/oauth/device/actions.ts b/src/app/oauth/device/actions.ts index d4643266..7e3f8bb8 100644 --- a/src/app/oauth/device/actions.ts +++ b/src/app/oauth/device/actions.ts @@ -4,61 +4,44 @@ import { redirect } from "next/navigation"; import { auth } from "@/auth"; import { buildSignInRedirectUrl } from "@/lib/auth/auth-routing"; import { prisma } from "@/lib/db/prisma"; +import { buildSearchParams } from "@/lib/navigation/search-params"; import { DEVICE_CODE_STATUS, normalizeUserCode } from "@/lib/oauth/device-code"; +type DevicePageUrlParams = { + code?: string; + step?: "approve"; + result?: "approved" | "denied" | "error"; + reason?: "missing_code" | "invalid_or_expired"; +}; + +function buildDevicePageUrl(values: DevicePageUrlParams = {}) { + const query = buildSearchParams({ values }); + return query ? `/oauth/device?${query}` : "/oauth/device"; +} + function buildDeviceCallbackUrl(rawCode: FormDataEntryValue | null) { if (typeof rawCode !== "string" || !rawCode.trim()) { - return "/oauth/device"; + return buildDevicePageUrl(); } - return `/oauth/device?${new URLSearchParams({ + return buildDevicePageUrl({ code: rawCode.trim(), step: "approve", - }).toString()}`; + }); } export async function approveDeviceCode(formData: FormData) { - const session = await auth(); - if (!session?.user?.id) { - redirect( - buildSignInRedirectUrl( - {}, - buildDeviceCallbackUrl(formData.get("userCode")), - ), - ); - } - - const rawCode = formData.get("userCode"); - if (typeof rawCode !== "string" || !rawCode.trim()) { - redirect("/oauth/device?result=error&reason=missing_code"); - } - - const userCode = normalizeUserCode(rawCode); - - const record = await prisma.deviceCode.findUnique({ - where: { userCode }, - }); - - if ( - !record || - record.status !== DEVICE_CODE_STATUS.PENDING || - record.expiresAt < new Date() - ) { - redirect("/oauth/device?result=error&reason=invalid_or_expired"); - } - - await prisma.deviceCode.update({ - where: { id: record.id }, - data: { - status: DEVICE_CODE_STATUS.APPROVED, - userId: session.user.id, - }, - }); - - redirect("/oauth/device?result=approved"); + await completeDeviceCodeDecision(formData, "approve"); } export async function denyDeviceCode(formData: FormData) { + await completeDeviceCodeDecision(formData, "deny"); +} + +async function completeDeviceCodeDecision( + formData: FormData, + decision: "approve" | "deny", +) { const session = await auth(); if (!session?.user?.id) { redirect( @@ -71,7 +54,7 @@ export async function denyDeviceCode(formData: FormData) { const rawCode = formData.get("userCode"); if (typeof rawCode !== "string" || !rawCode.trim()) { - redirect("/oauth/device?result=error&reason=missing_code"); + redirect(buildDevicePageUrl({ result: "error", reason: "missing_code" })); } const userCode = normalizeUserCode(rawCode); @@ -85,13 +68,22 @@ export async function denyDeviceCode(formData: FormData) { record.status !== DEVICE_CODE_STATUS.PENDING || record.expiresAt < new Date() ) { - redirect("/oauth/device?result=error&reason=invalid_or_expired"); + redirect( + buildDevicePageUrl({ result: "error", reason: "invalid_or_expired" }), + ); } + const approved = decision === "approve"; + await prisma.deviceCode.update({ where: { id: record.id }, - data: { status: DEVICE_CODE_STATUS.DENIED }, + data: approved + ? { + status: DEVICE_CODE_STATUS.APPROVED, + userId: session.user.id, + } + : { status: DEVICE_CODE_STATUS.DENIED }, }); - redirect("/oauth/device?result=denied"); + redirect(buildDevicePageUrl({ result: approved ? "approved" : "denied" })); } diff --git a/src/app/oauth/device/page.tsx b/src/app/oauth/device/page.tsx index 40e990e9..4842aa36 100644 --- a/src/app/oauth/device/page.tsx +++ b/src/app/oauth/device/page.tsx @@ -1,14 +1,60 @@ import type { Metadata } from "next"; import { redirect } from "next/navigation"; +import type { ReactNode } from "react"; import { auth } from "@/auth"; import { buildSignInRedirectUrl } from "@/lib/auth/auth-routing"; import { prisma } from "@/lib/db/prisma"; +import { buildSearchParams } from "@/lib/navigation/search-params"; import { DEVICE_CODE_STATUS, normalizeUserCode } from "@/lib/oauth/device-code"; import { approveDeviceCode, denyDeviceCode } from "./actions"; export const metadata: Metadata = { title: "Device Login" }; export const dynamic = "force-dynamic"; +function DevicePanel({ + children, + textCenter = false, +}: { + children: ReactNode; + textCenter?: boolean; +}) { + return ( +
+
+ {children} +
+
+ ); +} + +function DeviceMessage({ + title, + children, + showRetry = false, +}: { + title: string; + children: ReactNode; + showRetry?: boolean; +}) { + return ( + +

{title}

+

{children}

+ {showRetry ? ( + + Try again + + ) : null} +
+ ); +} + export default async function DeviceVerifyPage({ searchParams, }: { @@ -22,22 +68,16 @@ export default async function DeviceVerifyPage({ const [session, params] = await Promise.all([auth(), searchParams]); if (!session?.user?.id) { - const callbackParams = new URLSearchParams(); - if (params.code) { - callbackParams.set("code", params.code); - } - if (params.step) { - callbackParams.set("step", params.step); - } - if (params.result) { - callbackParams.set("result", params.result); - } - if (params.reason) { - callbackParams.set("reason", params.reason); - } - - const callbackUrl = callbackParams.size - ? `/oauth/device?${callbackParams.toString()}` + const callbackQuery = buildSearchParams({ + values: { + code: params.code, + step: params.step, + result: params.result, + reason: params.reason, + }, + }); + const callbackUrl = callbackQuery + ? `/oauth/device?${callbackQuery}` : "/oauth/device"; redirect(buildSignInRedirectUrl({}, callbackUrl)); } @@ -45,43 +85,41 @@ export default async function DeviceVerifyPage({ // Result screen after approve/deny if (params.result) { return ( -
-
- {params.result === "approved" && ( - <> -

Device Approved

-

- You have authorized the device. You can close this page. -

- - )} - {params.result === "denied" && ( - <> -

Device Denied

-

- The device login request was denied. -

- - )} - {params.result === "error" && ( - <> -

- Error -

-

- {params.reason === "missing_code" && - "No device code was provided."} - {params.reason === "invalid_or_expired" && - "The device code is invalid or has expired."} - {!params.reason && "An unknown error occurred."} -

- - Try again - - - )} -
-
+ + {params.result === "approved" && ( + <> +

Device Approved

+

+ You have authorized the device. You can close this page. +

+ + )} + {params.result === "denied" && ( + <> +

Device Denied

+

+ The device login request was denied. +

+ + )} + {params.result === "error" && ( + <> +

+ Error +

+

+ {params.reason === "missing_code" && + "No device code was provided."} + {params.reason === "invalid_or_expired" && + "The device code is invalid or has expired."} + {!params.reason && "An unknown error occurred."} +

+ + Try again + + + )} +
); } @@ -95,135 +133,104 @@ export default async function DeviceVerifyPage({ if (!record) { return ( -
-
-

- Code Not Found -

-

- No device login request matches this code. -

- - Try again - -
-
+ + No device login request matches this code. + ); } if (record.expiresAt < new Date()) { return ( -
-
-

- Code Expired -

-

- This device code has expired. Please start a new login on your - device. -

-
-
+ + This device code has expired. Please start a new login on your device. + ); } if (record.status !== DEVICE_CODE_STATUS.PENDING) { return ( -
-
-

- Code Already Used -

-

- This device code has already been {record.status}. -

-
-
+ + This device code has already been {record.status}. + ); } const clientName = record.client.name ?? record.client.clientId; return ( -
-
-

- Device Login -

-

- {clientName} is requesting access to your account. -

- - {record.scopes.length > 0 && ( -
-

Requested permissions:

-
    - {record.scopes.map((scope) => ( -
  • {scope}
  • - ))} -
-
- )} - -
-
- - -
-
- - -
+ +

+ Device Login +

+

+ {clientName} is requesting access to your account. +

+ + {record.scopes.length > 0 && ( +
+

Requested permissions:

+
    + {record.scopes.map((scope) => ( +
  • {scope}
  • + ))} +
+ )} + +
+
+ + +
+
+ + +
-
+ ); } // Default: code entry form return ( -
-
-

- Device Login -

-

- Enter the code displayed on your device. -

+ +

Device Login

+

+ Enter the code displayed on your device. +

-
- - - - -
-
-
+
+ + + + +
+ ); } diff --git a/src/app/signin/page.tsx b/src/app/signin/page.tsx index 446c40c3..4abef8bc 100644 --- a/src/app/signin/page.tsx +++ b/src/app/signin/page.tsx @@ -60,18 +60,16 @@ export default async function SignInPage({ const t = await getTranslations("signIn"); const showDebugProviders = allowDebugAuth; + const providerNames = { + [OIDC_PROVIDER_ID]: "USTC", + github: "GitHub", + google: "Google", + [DEV_DEBUG_PROVIDER_ID]: t("devDebugProvider"), + [DEV_ADMIN_PROVIDER_ID]: t("devAdminProvider"), + } satisfies Record[number], string>; const providers = getSignInProviderIds(showDebugProviders).map((id) => ({ id, - name: - id === OIDC_PROVIDER_ID - ? "USTC" - : id === DEV_DEBUG_PROVIDER_ID - ? t("devDebugProvider") - : id === DEV_ADMIN_PROVIDER_ID - ? t("devAdminProvider") - : id === "github" - ? "GitHub" - : "Google", + name: providerNames[id], })); return ( diff --git a/src/app/u/profile-view.tsx b/src/app/u/profile-view.tsx index 9bcfdcde..08ba85a9 100644 --- a/src/app/u/profile-view.tsx +++ b/src/app/u/profile-view.tsx @@ -41,7 +41,7 @@ export function ProfileView({ return (
- +
@@ -110,25 +110,25 @@ export function ProfileView({ - + {t("contribution.title", { count: totalContributions })} {t("contribution.description")} - -
-
+ +
+
{weeks.map((week, weekIndex) => (
{week.map((day) => (
-
+
{t("contribution.less")} diff --git a/src/auth.ts b/src/auth.ts index 0091fc4e..78d3121a 100644 --- a/src/auth.ts +++ b/src/auth.ts @@ -84,8 +84,8 @@ export async function signIn(providerId?: string, options: SignInOptions = {}) { const debugConfig = getDebugProviderConfig(decision.providerId); result = await authInstance.api.signInEmail({ body: { - email: debugConfig?.email ?? "", - password: debugConfig?.password ?? "", + email: debugConfig.email, + password: debugConfig.password, callbackURL: redirectTo, }, }); diff --git a/src/components/sign-in-link.tsx b/src/components/sign-in-link.tsx new file mode 100644 index 00000000..b6ccd646 --- /dev/null +++ b/src/components/sign-in-link.tsx @@ -0,0 +1,22 @@ +"use client"; + +import { useSearchParams } from "next/navigation"; +import type * as React from "react"; +import { Link, usePathname } from "@/i18n/routing"; +import { + buildCurrentPathCallbackUrl, + buildSignInPageUrl, +} from "@/lib/auth/auth-routing"; + +type SignInLinkProps = Omit, "href"> & { + callbackUrl?: string; +}; + +export function SignInLink({ callbackUrl, ...props }: SignInLinkProps) { + const pathname = usePathname(); + const searchParams = useSearchParams(); + const resolvedCallbackUrl = + callbackUrl ?? buildCurrentPathCallbackUrl(pathname, searchParams); + + return ; +} diff --git a/src/components/user-menu.tsx b/src/components/user-menu.tsx index 7153e80e..5d38ef04 100644 --- a/src/components/user-menu.tsx +++ b/src/components/user-menu.tsx @@ -1,12 +1,14 @@ "use client"; import { User as UserIcon } from "lucide-react"; +import { useRouter } from "next/navigation"; import { useTranslations } from "next-intl"; +import { useState } from "react"; +import { signOutCurrentUser } from "@/app/actions/auth"; import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; import { Button } from "@/components/ui/button"; import { Menu, MenuItem, MenuPopup, MenuTrigger } from "@/components/ui/menu"; import { Link } from "@/i18n/routing"; -import { signOut } from "@/lib/auth/client"; import { cn } from "@/lib/utils"; type UserMenuProps = { @@ -20,12 +22,15 @@ type UserMenuProps = { }; export function UserMenu({ className, initialUser = null }: UserMenuProps) { + const router = useRouter(); const tProfile = useTranslations("profile"); const tSettings = useTranslations("settings"); const tCommon = useTranslations("common"); + const [isSigningOut, setIsSigningOut] = useState(false); + const [isSignedOut, setIsSignedOut] = useState(false); const user = initialUser; - if (!user) { + if (!user || isSignedOut) { return null; } @@ -37,6 +42,21 @@ export function UserMenu({ className, initialUser = null }: UserMenuProps) { ? `/u/id/${user.id}` : "/"; + const handleSignOut = async () => { + if (isSigningOut) return; + + setIsSigningOut(true); + const result = await signOutCurrentUser(); + if (result.error) { + setIsSigningOut(false); + return; + } + + setIsSignedOut(true); + router.push("/"); + router.refresh(); + }; + return (
@@ -75,7 +95,7 @@ export function UserMenu({ className, initialUser = null }: UserMenuProps) { > {tSettings("title")} - signOut({ callbackUrl: "/" })}> + {tProfile("signOut")} diff --git a/src/env.ts b/src/env.ts index 44550722..a7ac602b 100644 --- a/src/env.ts +++ b/src/env.ts @@ -2,313 +2,134 @@ import * as z from "zod"; export const NEXT_PRODUCTION_BUILD_PHASE = "phase-production-build"; -const optionalStringSchema = z.string().optional(); -const optionalUrlSchema = z.string().url().optional(); -const optionalPositiveIntSchema = z.number().int().positive().optional(); +function trimOrUndefined(value: string | undefined) { + const trimmed = value?.trim(); + return trimmed || undefined; +} -const commonEnvSchema = z.object({ - DATABASE_URL: optionalStringSchema, - APP_PUBLIC_ORIGIN: optionalUrlSchema, - APP_CANONICAL_ORIGIN: optionalUrlSchema, - JWT_SECRET: optionalStringSchema, - AUTH_SECRET: optionalStringSchema, - BETTER_AUTH_SECRET: optionalStringSchema, +function parsePositiveIntOrUndefined(value: string | undefined) { + const trimmed = trimOrUndefined(value); + if (!trimmed) return undefined; + return /^\d+$/.test(trimmed) ? Number(trimmed) : Number.NaN; +} - // OAuth providers (optional) - AUTH_GITHUB_ID: optionalStringSchema, - AUTH_GITHUB_SECRET: optionalStringSchema, - AUTH_GOOGLE_ID: optionalStringSchema, - AUTH_GOOGLE_SECRET: optionalStringSchema, - AUTH_OIDC_ISSUER: optionalUrlSchema, - AUTH_OIDC_CLIENT_ID: optionalStringSchema, - AUTH_OIDC_CLIENT_SECRET: optionalStringSchema, - OAUTH_PROXY_SECRET: optionalStringSchema, +/* ------------------------------------------------------------------ */ +/* Schemas */ +/* ------------------------------------------------------------------ */ - // S3 storage (optional outside upload flows) - S3_BUCKET: optionalStringSchema, - AWS_REGION: optionalStringSchema, - AWS_ENDPOINT_URL: optionalUrlSchema, - AWS_ENDPOINT_URL_S3: optionalUrlSchema, - AWS_ACCESS_KEY_ID: optionalStringSchema, - AWS_SECRET_ACCESS_KEY: optionalStringSchema, - AWS_SESSION_TOKEN: optionalStringSchema, +const optionalString = z.string().optional(); +const optionalUrl = z.string().url().optional(); +const optionalPositiveInt = z.number().int().positive().optional(); - // Runtime +const commonEnvSchema = z.object({ + DATABASE_URL: optionalString, + APP_PUBLIC_ORIGIN: optionalUrl, + APP_CANONICAL_ORIGIN: optionalUrl, + AUTH_SECRET: optionalString, + AUTH_GITHUB_ID: optionalString, + AUTH_GITHUB_SECRET: optionalString, + AUTH_GOOGLE_ID: optionalString, + AUTH_GOOGLE_SECRET: optionalString, + AUTH_OIDC_ISSUER: optionalUrl, + AUTH_OIDC_CLIENT_ID: optionalString, + AUTH_OIDC_CLIENT_SECRET: optionalString, + OAUTH_PROXY_SECRET: optionalString, + S3_BUCKET: optionalString, + AWS_REGION: optionalString, + AWS_ENDPOINT_URL_S3: optionalUrl, NODE_ENV: z .enum(["development", "production", "test"]) .default("development"), - LOG_LEVEL: z.enum(["debug", "info", "warn", "error"]).optional(), - WEBHOOK_SECRET: optionalStringSchema, - UPLOAD_TOTAL_QUOTA_MB: optionalPositiveIntSchema, - OAUTH_DEBUG_LOGGING: optionalStringSchema, - E2E_DEBUG_AUTH: optionalStringSchema, - PRISMA_QUERY_DEBUG: optionalStringSchema, - PRISMA_SLOW_QUERY_MS: optionalStringSchema, - - // Deployment metadata - VERCEL: optionalStringSchema, - VERCEL_URL: optionalStringSchema, - VERCEL_PROJECT_PRODUCTION_URL: optionalStringSchema, -}); - -const runtimeRequiredEnvSchema = z - .object({ - DATABASE_URL: z.string().min(1, "DATABASE_URL is required"), - JWT_SECRET: z.string().min(1, "JWT_SECRET is required"), - AUTH_SECRET: optionalStringSchema, - BETTER_AUTH_SECRET: optionalStringSchema, - }) - .refine( - ({ AUTH_SECRET, BETTER_AUTH_SECRET }) => - Boolean(AUTH_SECRET || BETTER_AUTH_SECRET), - { - message: "AUTH_SECRET or BETTER_AUTH_SECRET is required", - path: ["AUTH_SECRET"], - }, - ); - -const authEnvSchema = commonEnvSchema.pick({ - AUTH_GITHUB_ID: true, - AUTH_GITHUB_SECRET: true, - AUTH_GOOGLE_ID: true, - AUTH_GOOGLE_SECRET: true, - AUTH_OIDC_ISSUER: true, - AUTH_OIDC_CLIENT_ID: true, - AUTH_OIDC_CLIENT_SECRET: true, - AUTH_SECRET: true, - BETTER_AUTH_SECRET: true, - OAUTH_PROXY_SECRET: true, - E2E_DEBUG_AUTH: true, - NODE_ENV: true, - VERCEL: true, -}); - -const uploadEnvSchema = commonEnvSchema.pick({ - UPLOAD_TOTAL_QUOTA_MB: true, + UPLOAD_TOTAL_QUOTA_MB: optionalPositiveInt, + E2E_DEBUG_AUTH: optionalString, + VERCEL: optionalString, }); -const storageEnvSchema = commonEnvSchema.pick({ - S3_BUCKET: true, - AWS_REGION: true, - AWS_ENDPOINT_URL: true, - AWS_ENDPOINT_URL_S3: true, - AWS_ACCESS_KEY_ID: true, - AWS_SECRET_ACCESS_KEY: true, - AWS_SESSION_TOKEN: true, +const runtimeRequiredEnvSchema = z.object({ + DATABASE_URL: z.string().min(1, "DATABASE_URL is required"), + AUTH_SECRET: z.string().min(1, "AUTH_SECRET is required"), }); -const envCacheKeys = [ - "DATABASE_URL", - "APP_PUBLIC_ORIGIN", - "APP_CANONICAL_ORIGIN", - "JWT_SECRET", - "AUTH_SECRET", - "BETTER_AUTH_SECRET", - "AUTH_GITHUB_ID", - "AUTH_GITHUB_SECRET", - "AUTH_GOOGLE_ID", - "AUTH_GOOGLE_SECRET", - "AUTH_OIDC_ISSUER", - "AUTH_OIDC_CLIENT_ID", - "AUTH_OIDC_CLIENT_SECRET", - "OAUTH_PROXY_SECRET", - "S3_BUCKET", - "AWS_REGION", - "AWS_ENDPOINT_URL", - "AWS_ENDPOINT_URL_S3", - "AWS_ACCESS_KEY_ID", - "AWS_SECRET_ACCESS_KEY", - "AWS_SESSION_TOKEN", - "NODE_ENV", - "LOG_LEVEL", - "WEBHOOK_SECRET", - "UPLOAD_TOTAL_QUOTA_MB", - "OAUTH_DEBUG_LOGGING", - "E2E_DEBUG_AUTH", - "PRISMA_QUERY_DEBUG", - "PRISMA_SLOW_QUERY_MS", - "VERCEL", - "VERCEL_URL", - "VERCEL_PROJECT_PRODUCTION_URL", -] as const; - -export type Env = z.output; -export type DevEnv = Env; -export type AuthEnv = z.output; -export type UploadEnv = z.output; -export type StorageEnv = z.output; - -type LoadEnvOptions = { - input?: NodeJS.ProcessEnv; - nextPhase?: string; -}; - -let cachedEnv: Env | null = null; -let cachedEnvKey: string | null = null; - -function normalizeOptionalEnvValue(value: string | undefined) { - const trimmed = value?.trim(); - return trimmed ? trimmed : undefined; -} - -function parseOptionalPositiveInt(value: string | undefined) { - const trimmed = normalizeOptionalEnvValue(value); - if (!trimmed) { - return undefined; - } +type Env = z.output; - const parsed = Number.parseInt(trimmed, 10); - return Number.isNaN(parsed) ? Number.NaN : parsed; -} +/* ------------------------------------------------------------------ */ +/* Env normalization */ +/* ------------------------------------------------------------------ */ function normalizeEnvInput(input: NodeJS.ProcessEnv) { - const entries = Object.entries(input).map(([key, value]) => { - if (key === "UPLOAD_TOTAL_QUOTA_MB") { - return [key, parseOptionalPositiveInt(value)] as const; - } - - return [key, normalizeOptionalEnvValue(value)] as const; - }); - - return Object.fromEntries(entries.filter(([, value]) => value !== undefined)); + return Object.fromEntries( + Object.entries(input) + .map(([key, value]) => { + if (key === "UPLOAD_TOTAL_QUOTA_MB") + return [key, parsePositiveIntOrUndefined(value)]; + return [key, trimOrUndefined(value)]; + }) + .filter(([, v]) => v !== undefined), + ); } -function formatEnvIssues(issues: z.ZodIssue[]) { - return issues - .map((issue) => ` ${issue.path.join(".")}: ${issue.message}`) - .join("\n"); -} +/* ------------------------------------------------------------------ */ +/* Load & validate */ +/* ------------------------------------------------------------------ */ -function buildEnvCacheKey(input: NodeJS.ProcessEnv, nextPhase?: string) { - return [ - ...envCacheKeys.map((key) => `${key}=${input[key] ?? ""}`), - `NEXT_PHASE=${nextPhase ?? ""}`, - ].join("\u0000"); +function formatIssues(issues: z.ZodIssue[]) { + return issues.map((i) => ` ${i.path.join(".")}: ${i.message}`).join("\n"); } -function parseEnv( - schema: TSchema, +function parseEnv( + schema: T, input: NodeJS.ProcessEnv, - errorPrefix = "Invalid environment variables", -): z.output { + prefix = "Invalid environment variables", +): z.output { const result = schema.safeParse(normalizeEnvInput(input)); - if (result.success) { - return result.data; - } - - throw new Error(`${errorPrefix}:\n${formatEnvIssues(result.error.issues)}`); + if (result.success) return result.data; + throw new Error(`${prefix}:\n${formatIssues(result.error.issues)}`); } -function validateRuntimeEnv(input: NodeJS.ProcessEnv, nextPhase?: string): Env { - const normalized = normalizeEnvInput(input); - const result = commonEnvSchema.safeParse(normalized); +export function loadEnv( + options: { input?: NodeJS.ProcessEnv; nextPhase?: string } = {}, +): Env { + const input = options.input ?? process.env; + const nextPhase = options.nextPhase ?? trimOrUndefined(input.NEXT_PHASE); + const result = commonEnvSchema.safeParse(normalizeEnvInput(input)); if (!result.success) { - const formatted = formatEnvIssues(result.error.issues); - console.error(`❌ Invalid environment variables:\n${formatted}`); + console.error( + `❌ Invalid environment variables:\n${formatIssues(result.error.issues)}`, + ); throw new Error("Invalid environment variables"); } - const validatedEnv = result.data; + const env = result.data; + // Skip runtime-required checks during production build and development if ( nextPhase === NEXT_PRODUCTION_BUILD_PHASE || - validatedEnv.NODE_ENV === "development" + env.NODE_ENV === "development" ) { - return validatedEnv; + return env; } - const runtimeResult = runtimeRequiredEnvSchema.safeParse(validatedEnv); + const runtimeResult = runtimeRequiredEnvSchema.safeParse(env); if (!runtimeResult.success) { - const formatted = formatEnvIssues(runtimeResult.error.issues); - console.error(`❌ Invalid environment variables:\n${formatted}`); + console.error( + `❌ Invalid environment variables:\n${formatIssues(runtimeResult.error.issues)}`, + ); throw new Error("Invalid environment variables"); } - return validatedEnv; + return env; } -export function loadEnv(options: LoadEnvOptions = {}): Env { - const input = options.input ?? process.env; - const nextPhase = - options.nextPhase ?? getOptionalTrimmedEnv("NEXT_PHASE", input); - - if (input === process.env) { - const nextCacheKey = buildEnvCacheKey(input, nextPhase); - if (cachedEnv && cachedEnvKey === nextCacheKey) { - return cachedEnv; - } - - const validated = validateRuntimeEnv(input, nextPhase); - cachedEnv = validated; - cachedEnvKey = nextCacheKey; - return validated; - } - - return validateRuntimeEnv(input, nextPhase); -} - -export const env = new Proxy({} as Env, { - get(_target, property) { - return loadEnv()[property as keyof Env]; - }, -}); - -export function getOptionalEnvValue( - name: string, - input: NodeJS.ProcessEnv = process.env, -) { - return input[name]; -} +/* ------------------------------------------------------------------ */ +/* Convenience getters */ +/* ------------------------------------------------------------------ */ export function getOptionalTrimmedEnv( name: string, input: NodeJS.ProcessEnv = process.env, ) { - return normalizeOptionalEnvValue(getOptionalEnvValue(name, input)); -} - -export function getOptionalLowercaseEnv( - name: string, - input: NodeJS.ProcessEnv = process.env, -) { - return getOptionalTrimmedEnv(name, input)?.toLowerCase(); -} - -export function getFirstOptionalTrimmedEnv( - names: readonly string[], - input: NodeJS.ProcessEnv = process.env, -) { - for (const name of names) { - const value = getOptionalTrimmedEnv(name, input); - if (value) { - return value; - } - } - - return undefined; -} - -export function getOptionalIntEnv( - name: string, - input: NodeJS.ProcessEnv = process.env, -) { - const parsed = parseOptionalPositiveInt(input[name]); - return Number.isNaN(parsed) ? undefined : parsed; -} - -export function getEnvFlag( - name: string, - input: NodeJS.ProcessEnv = process.env, -) { - return getOptionalTrimmedEnv(name, input) === "1"; -} - -export function isNodeEnv( - value: "development" | "production" | "test", - input: NodeJS.ProcessEnv = process.env, -) { - return getOptionalTrimmedEnv("NODE_ENV", input) === value; + return trimOrUndefined(input[name]); } export function isNextProductionBuildPhase( @@ -319,25 +140,42 @@ export function isNextProductionBuildPhase( ); } -export function getAuthEnv(input: NodeJS.ProcessEnv = process.env): AuthEnv { - return parseEnv(authEnvSchema, input, "Invalid auth environment variables"); +export function getAuthEnv(input: NodeJS.ProcessEnv = process.env) { + return parseEnv( + commonEnvSchema.pick({ + AUTH_GITHUB_ID: true, + AUTH_GITHUB_SECRET: true, + AUTH_GOOGLE_ID: true, + AUTH_GOOGLE_SECRET: true, + AUTH_OIDC_ISSUER: true, + AUTH_OIDC_CLIENT_ID: true, + AUTH_OIDC_CLIENT_SECRET: true, + AUTH_SECRET: true, + OAUTH_PROXY_SECRET: true, + E2E_DEBUG_AUTH: true, + NODE_ENV: true, + VERCEL: true, + }), + input, + "Invalid auth environment variables", + ); } -export function getUploadEnv( - input: NodeJS.ProcessEnv = process.env, -): UploadEnv { +export function getUploadEnv(input: NodeJS.ProcessEnv = process.env) { return parseEnv( - uploadEnvSchema, + commonEnvSchema.pick({ UPLOAD_TOTAL_QUOTA_MB: true }), input, "Invalid upload environment variables", ); } -export function getStorageEnv( - input: NodeJS.ProcessEnv = process.env, -): StorageEnv { +export function getStorageEnv(input: NodeJS.ProcessEnv = process.env) { return parseEnv( - storageEnvSchema, + commonEnvSchema.pick({ + S3_BUCKET: true, + AWS_REGION: true, + AWS_ENDPOINT_URL_S3: true, + }), input, "Invalid storage environment variables", ); diff --git a/src/lib/auth/auth-config.ts b/src/lib/auth/auth-config.ts index 493278d1..0e1ac99a 100644 --- a/src/lib/auth/auth-config.ts +++ b/src/lib/auth/auth-config.ts @@ -1,72 +1,20 @@ -import { - getAuthEnv, - getFirstOptionalTrimmedEnv, - isNextProductionBuildPhase, -} from "@/env"; -import { MCP_TOOLS_SCOPE } from "@/lib/oauth/utils"; -import { getBetterAuthBaseUrl, getPublicOrigin } from "@/lib/site-url"; +import { getAuthEnv, isNextProductionBuildPhase } from "@/env"; const authEnv = getAuthEnv(); export const isDevelopment = authEnv.NODE_ENV === "development"; -export const isE2EDebugAuthEnabled = authEnv.E2E_DEBUG_AUTH === "1"; -export const allowDebugAuth = isDevelopment || isE2EDebugAuthEnabled; +export const allowE2EDebugAuth = authEnv.E2E_DEBUG_AUTH === "1"; +export const allowDebugAuth = isDevelopment || allowE2EDebugAuth; -if (isE2EDebugAuthEnabled && authEnv.VERCEL === "1") { +if (allowE2EDebugAuth && authEnv.VERCEL === "1") { throw new Error( "E2E_DEBUG_AUTH must not be set on Vercel/production hosting", ); } -export const OIDC_ISSUER = - authEnv.AUTH_OIDC_ISSUER ?? "https://sso-proxy.lug.ustc.edu.cn/auth/oauth2"; -export const OIDC_DISCOVERY_URL = `${OIDC_ISSUER.replace(/\/$/, "")}/.well-known/openid-configuration`; -export const AUTH_BASE_URL = getBetterAuthBaseUrl(); -export const AUTH_PUBLIC_ORIGIN = getPublicOrigin(); -export const AUTH_PUBLIC_PROTOCOL = getAuthPublicProtocol(AUTH_PUBLIC_ORIGIN); -export const OAUTH_PROXY_SECRET = authEnv.OAUTH_PROXY_SECRET; -export const AUTH_GITHUB = getProviderCredentials( - authEnv.AUTH_GITHUB_ID, - authEnv.AUTH_GITHUB_SECRET, -); -export const AUTH_GOOGLE = getProviderCredentials( - authEnv.AUTH_GOOGLE_ID, - authEnv.AUTH_GOOGLE_SECRET, -); -export const AUTH_OIDC = { - clientId: authEnv.AUTH_OIDC_CLIENT_ID ?? "", - clientSecret: authEnv.AUTH_OIDC_CLIENT_SECRET ?? "", -}; -export const OAUTH_PROVIDER_SCOPES = [ - "openid", - "profile", - "email", - "offline_access", - MCP_TOOLS_SCOPE, -] as const; - -function getAuthPublicProtocol(origin: string): "http" | "https" { - const protocol = new URL(origin).protocol; - if (protocol === "http:" || protocol === "https:") { - return protocol.slice(0, -1) as "http" | "https"; - } - throw new Error(`Unsupported auth origin protocol: ${protocol}`); -} - -function getProviderCredentials( - clientId: string | undefined, - clientSecret: string | undefined, -) { - return clientId && clientSecret ? { clientId, clientSecret } : null; -} - export function getBetterAuthSecret() { - const secret = getFirstOptionalTrimmedEnv( - ["AUTH_SECRET", "BETTER_AUTH_SECRET"], - authEnv, - ); - if (secret) { - return secret; + if (authEnv.AUTH_SECRET) { + return authEnv.AUTH_SECRET; } if (isNextProductionBuildPhase()) { @@ -74,7 +22,7 @@ export function getBetterAuthSecret() { } if (authEnv.NODE_ENV === "production") { - throw new Error("AUTH_SECRET or BETTER_AUTH_SECRET is required"); + throw new Error("AUTH_SECRET is required"); } return undefined; diff --git a/src/lib/auth/auth-origins.ts b/src/lib/auth/auth-origins.ts index dd0641bc..550cef4e 100644 --- a/src/lib/auth/auth-origins.ts +++ b/src/lib/auth/auth-origins.ts @@ -1,4 +1,3 @@ -import { getOptionalTrimmedEnv } from "@/env"; import { getCanonicalOrigin, getPublicOrigin } from "@/lib/site-url"; const LOCALHOST_DEV_PORT = 3000; @@ -8,32 +7,30 @@ const LOCALHOST_AUTH_ORIGINS = [ ]; const VERCEL_PREVIEW_AUTH_ORIGIN = "https://*.vercel.app"; +function uniqueOrigins(origins: string[]): string[] { + return Array.from(new Set(origins)); +} + +function normalizeOriginOrNull(origin: string): string | null { + try { + return new URL(origin).origin; + } catch { + return null; + } +} + export function getAuthTrustedOrigins(): string[] { - return Array.from( - new Set([ - getPublicOrigin(), - getCanonicalOrigin(), - ...LOCALHOST_AUTH_ORIGINS, - VERCEL_PREVIEW_AUTH_ORIGIN, - ]), - ); + return uniqueOrigins([ + getPublicOrigin(), + getCanonicalOrigin(), + ...LOCALHOST_AUTH_ORIGINS, + VERCEL_PREVIEW_AUTH_ORIGIN, + ]); } export function getAuthAllowedHosts(): string[] { - return Array.from( - new Set( - [ - getPublicOrigin(), - getCanonicalOrigin(), - ...LOCALHOST_AUTH_ORIGINS, - VERCEL_PREVIEW_AUTH_ORIGIN, - ].map((origin) => { - if (origin.includes("://")) { - return new URL(origin).host; - } - return origin.replace(/^https?:\/\//, ""); - }), - ), + return uniqueOrigins( + getAuthTrustedOrigins().map((origin) => new URL(origin).host), ); } @@ -46,24 +43,19 @@ function matchesTrustedOrigin(origin: string, trustedOrigin: string) { return origin === trustedOrigin; } - const protocolSeparator = trustedOrigin.indexOf("://"); - const trustedProtocol = trustedOrigin.slice(0, protocolSeparator); - const trustedHostPattern = trustedOrigin.slice(protocolSeparator + 3); - const trustedHostSuffix = trustedHostPattern.slice(1); - + const trustedUrl = new URL(trustedOrigin); + const trustedHostSuffix = trustedUrl.hostname.slice(1); const url = new URL(origin); return ( - url.protocol === `${trustedProtocol}:` && + url.protocol === trustedUrl.protocol && url.hostname.endsWith(trustedHostSuffix) && url.hostname.length > trustedHostSuffix.length ); } export function isTrustedAuthOrigin(origin: string): boolean { - let normalizedOrigin: string; - try { - normalizedOrigin = new URL(origin).origin; - } catch { + const normalizedOrigin = normalizeOriginOrNull(origin); + if (!normalizedOrigin) { return false; } @@ -71,15 +63,3 @@ export function isTrustedAuthOrigin(origin: string): boolean { matchesTrustedOrigin(normalizedOrigin, trustedOrigin), ); } - -export function getOAuthProxyProductionUrl(): string { - return getCanonicalOrigin(); -} - -export function getOAuthProxyCurrentUrl(): string { - return getPublicOrigin(); -} - -export function getOAuthProxySecret(): string | undefined { - return getOptionalTrimmedEnv("OAUTH_PROXY_SECRET"); -} diff --git a/src/lib/auth/auth-routing.ts b/src/lib/auth/auth-routing.ts index 8ec7e9ef..b015f841 100644 --- a/src/lib/auth/auth-routing.ts +++ b/src/lib/auth/auth-routing.ts @@ -1,3 +1,5 @@ +import { buildSearchParams } from "@/lib/navigation/search-params"; + type SignInSearchParams = { callbackUrl?: string; error?: string; @@ -20,6 +22,14 @@ export function buildSignInPageUrl(callbackUrl: string) { return `/signin?callbackUrl=${encodeURIComponent(callbackUrl)}`; } +export function buildCurrentPathCallbackUrl( + pathname: string, + searchParams?: { toString(): string } | null, +) { + const queryString = searchParams?.toString(); + return queryString ? `${pathname}?${queryString}` : pathname; +} + export function buildSignInRedirectUrl( options: AuthRedirectOptions = {}, fallbackUrl = "/", @@ -32,15 +42,14 @@ export function resolveSignInCallbackUrl(params: SignInSearchParams): string { return params.callbackUrl; } - const authorizeQuery = new URLSearchParams(); - for (const [key, value] of Object.entries(params)) { - if (key === "callbackUrl" || key === "error") { - continue; - } - if (typeof value === "string" && value.length > 0) { - authorizeQuery.set(key, value); - } - } + const { + callbackUrl: _callbackUrl, + error: _error, + ...continuationParams + } = params; + const authorizeQuery = new URLSearchParams( + buildSearchParams({ values: continuationParams }), + ); if (!authorizeQuery.has("client_id") || !authorizeQuery.has("redirect_uri")) { return "/"; @@ -49,7 +58,7 @@ export function resolveSignInCallbackUrl(params: SignInSearchParams): string { return `/oauth/authorize?${authorizeQuery.toString()}`; } -export function isOAuthCallbackContinuation(url: URL): boolean { +function isOAuthCallbackContinuation(url: URL): boolean { const hasState = url.searchParams.has("state"); const hasResult = url.searchParams.has("code") || url.searchParams.has("error"); diff --git a/src/lib/auth/better-auth-options.ts b/src/lib/auth/better-auth-options.ts index ac0ce662..6c979133 100644 --- a/src/lib/auth/better-auth-options.ts +++ b/src/lib/auth/better-auth-options.ts @@ -2,32 +2,21 @@ import { oauthProvider } from "@better-auth/oauth-provider"; import type { betterAuth } from "better-auth"; import { nextCookies } from "better-auth/next-js"; import { genericOAuth, jwt, oAuthProxy } from "better-auth/plugins"; +import { getAuthEnv } from "@/env"; import { - AUTH_GITHUB, - AUTH_GOOGLE, - AUTH_OIDC, - AUTH_PUBLIC_ORIGIN, - AUTH_PUBLIC_PROTOCOL, allowDebugAuth, getBetterAuthSecret, isDevelopment, - OAUTH_PROVIDER_SCOPES, - OAUTH_PROXY_SECRET, - OIDC_DISCOVERY_URL, - OIDC_ISSUER, } from "@/lib/auth/auth-config"; import { getAuthAllowedHosts, getAuthTrustedOrigins, - getOAuthProxyCurrentUrl, - getOAuthProxyProductionUrl, } from "@/lib/auth/auth-origins"; import { createBetterAuthPrismaAdapter } from "@/lib/auth/better-auth-prisma-adapter"; import { - fallbackEmail, + mapGithubProfileToUser, + mapGoogleProfileToUser, mapOidcProfileToUser, - profileImage, - profileName, } from "@/lib/auth/oauth-profile"; import { webhookLoginPlugin } from "@/lib/auth/webhook-login-plugin"; import { prisma } from "@/lib/db/prisma"; @@ -37,6 +26,43 @@ import { getCanonicalOAuthIssuer, getOAuthProviderValidAudiences, } from "@/lib/mcp/urls"; +import { + OAUTH_OPENID_SCOPE, + OAUTH_PROFILE_SCOPE, + OAUTH_PROVIDER_SCOPES, +} from "@/lib/oauth/constants"; +import { getCanonicalOrigin, getPublicOrigin } from "@/lib/site-url"; + +const authEnv = getAuthEnv(); +const AUTH_PUBLIC_ORIGIN = getPublicOrigin(); +const AUTH_PUBLIC_PROTOCOL = getAuthPublicProtocol(AUTH_PUBLIC_ORIGIN); +const OAUTH_PROXY_SECRET = authEnv.OAUTH_PROXY_SECRET; +const OIDC_ISSUER = + authEnv.AUTH_OIDC_ISSUER ?? "https://sso-proxy.lug.ustc.edu.cn/auth/oauth2"; +const OIDC_DISCOVERY_URL = `${OIDC_ISSUER.replace(/\/$/, "")}/.well-known/openid-configuration`; +const AUTH_GITHUB = getProviderCredentials( + authEnv.AUTH_GITHUB_ID, + authEnv.AUTH_GITHUB_SECRET, +); +const AUTH_GOOGLE = getProviderCredentials( + authEnv.AUTH_GOOGLE_ID, + authEnv.AUTH_GOOGLE_SECRET, +); + +function getAuthPublicProtocol(origin: string): "http" | "https" { + const protocol = new URL(origin).protocol; + if (protocol === "http:" || protocol === "https:") { + return protocol.slice(0, -1) as "http" | "https"; + } + throw new Error(`Unsupported auth origin protocol: ${protocol}`); +} + +function getProviderCredentials( + clientId: string | undefined, + clientSecret: string | undefined, +) { + return clientId && clientSecret ? { clientId, clientSecret } : null; +} export function buildBetterAuthOptions() { const options = { @@ -64,26 +90,7 @@ export function buildBetterAuthOptions() { github: { clientId: AUTH_GITHUB.clientId, clientSecret: AUTH_GITHUB.clientSecret, - mapProfileToUser: (profile: { - email?: string | null; - id: string; - name?: string; - login?: string; - avatar_url?: string; - }) => { - const hasEmail = - typeof profile.email === "string" && profile.email.length > 0; - return { - email: hasEmail - ? profile.email - : fallbackEmail("github", profile.id), - name: profileName(profile.name ?? profile.login), - image: profileImage(profile.avatar_url), - // GitHub may return unverified or hidden emails; do not mark - // fallback/local emails as verified. - emailVerified: false, - }; - }, + mapProfileToUser: mapGithubProfileToUser, }, } : {}), @@ -92,27 +99,7 @@ export function buildBetterAuthOptions() { google: { clientId: AUTH_GOOGLE.clientId, clientSecret: AUTH_GOOGLE.clientSecret, - mapProfileToUser: (profile: { - email?: string; - sub: string; - name?: string; - picture?: string; - email_verified?: boolean; - }) => { - const hasEmail = - typeof profile.email === "string" && profile.email.length > 0; - return { - email: hasEmail - ? profile.email - : fallbackEmail("google", profile.sub), - name: profileName(profile.name), - image: profileImage(profile.picture), - emailVerified: - hasEmail && typeof profile.email_verified === "boolean" - ? profile.email_verified - : false, - }; - }, + mapProfileToUser: mapGoogleProfileToUser, }, } : {}), @@ -186,8 +173,8 @@ export function buildBetterAuthOptions() { }, }), oAuthProxy({ - productionURL: getOAuthProxyProductionUrl(), - currentURL: getOAuthProxyCurrentUrl(), + productionURL: getCanonicalOrigin(), + currentURL: AUTH_PUBLIC_ORIGIN, ...(OAUTH_PROXY_SECRET ? { secret: OAUTH_PROXY_SECRET } : {}), }), webhookLoginPlugin(), @@ -243,7 +230,7 @@ export function buildBetterAuthOptions() { scopes: string[]; }) { const claims: Record = {}; - if (scopes.includes("profile")) { + if (scopes.includes(OAUTH_PROFILE_SCOPE)) { const username = user.username; if (typeof username === "string" && username.length > 0) { claims.preferred_username = username; @@ -258,9 +245,9 @@ export function buildBetterAuthOptions() { providerId: "oidc", discoveryUrl: OIDC_DISCOVERY_URL, issuer: OIDC_ISSUER, - clientId: AUTH_OIDC.clientId, - clientSecret: AUTH_OIDC.clientSecret, - scopes: ["openid"], + clientId: authEnv.AUTH_OIDC_CLIENT_ID ?? "", + clientSecret: authEnv.AUTH_OIDC_CLIENT_SECRET ?? "", + scopes: [OAUTH_OPENID_SCOPE], pkce: true, mapProfileToUser: mapOidcProfileToUser, }, @@ -279,9 +266,12 @@ export function buildBetterAuthOptions() { ); } if (isOAuthDebugLogging()) { + const errorMessage = + error instanceof Error ? error.message : String(error); + const errorName = error instanceof Error ? error.name : "unknown"; logOAuthDebug("better-auth.api-error", undefined, { - message: error instanceof Error ? error.message : String(error), - name: error instanceof Error ? error.name : "unknown", + message: errorMessage, + name: errorName, }); } }, diff --git a/src/lib/auth/debug-auth.ts b/src/lib/auth/debug-auth.ts index 72c33fc3..46cf6191 100644 --- a/src/lib/auth/debug-auth.ts +++ b/src/lib/auth/debug-auth.ts @@ -1,12 +1,12 @@ import { hashPassword } from "better-auth/crypto"; -import { getOptionalLowercaseEnv, getOptionalTrimmedEnv } from "@/env"; +import { getOptionalTrimmedEnv } from "@/env"; import { Prisma } from "@/generated/prisma/client"; import { prisma } from "@/lib/db/prisma"; -import { allowDebugAuth, isDevelopment } from "./auth-config"; +import { allowE2EDebugAuth, isDevelopment } from "./auth-config"; import { + DEV_ADMIN_PROVIDER_ID, DEV_DEBUG_PROVIDER_ID, type DebugProviderId, - isDebugProviderId, } from "./provider-ids"; type DebugProviderConfig = { @@ -18,71 +18,102 @@ type DebugProviderConfig = { image: string; }; -const DEV_DEBUG_USERNAME = - getOptionalLowercaseEnv("DEV_DEBUG_USERNAME") ?? "dev-user"; -const DEV_DEBUG_NAME = - getOptionalTrimmedEnv("DEV_DEBUG_NAME") ?? "Dev Debug User"; -const DEV_ADMIN_USERNAME = - getOptionalLowercaseEnv("DEV_ADMIN_USERNAME") ?? "dev-admin"; -const DEV_ADMIN_NAME = - getOptionalTrimmedEnv("DEV_ADMIN_NAME") ?? "Dev Admin User"; -const DEV_DEBUG_EMAIL = - getOptionalLowercaseEnv("DEV_DEBUG_EMAIL") ?? - `${DEV_DEBUG_USERNAME}@debug.local`; -const DEV_ADMIN_EMAIL = - getOptionalLowercaseEnv("DEV_ADMIN_EMAIL") ?? - `${DEV_ADMIN_USERNAME}@debug.local`; - -const DEV_DEBUG_PASSWORD = (() => { - const value = getOptionalTrimmedEnv("DEV_DEBUG_PASSWORD"); - if (allowDebugAuth && !isDevelopment) { - if (!value) { - throw new Error( - "DEV_DEBUG_PASSWORD is required when E2E_DEBUG_AUTH=1 (non-development NODE_ENV)", - ); - } - return value; - } - return value || "dev-debug-password"; -})(); +type DebugProviderDefaults = { + usernameEnv: string; + username: string; + nameEnv: string; + name: string; + emailEnv: string; + passwordEnv: string; + password: string; + isAdmin: boolean; + imageSeed: string; +}; + +const DEBUG_PROVIDER_DEFAULTS: Record = + { + [DEV_DEBUG_PROVIDER_ID]: { + usernameEnv: "DEV_DEBUG_USERNAME", + username: "dev-user", + nameEnv: "DEV_DEBUG_NAME", + name: "Dev User", + emailEnv: "DEV_DEBUG_EMAIL", + passwordEnv: "DEV_DEBUG_PASSWORD", + password: "dev-debug-password", + isAdmin: false, + imageSeed: "life-ustc-dev-user", + }, + [DEV_ADMIN_PROVIDER_ID]: { + usernameEnv: "DEV_ADMIN_USERNAME", + username: "dev-admin", + nameEnv: "DEV_ADMIN_NAME", + name: "Dev Admin User", + emailEnv: "DEV_ADMIN_EMAIL", + passwordEnv: "DEV_ADMIN_PASSWORD", + password: "dev-admin-password", + isAdmin: true, + imageSeed: "life-ustc-dev-admin", + }, + }; + +const requiresExplicitDebugPassword = allowE2EDebugAuth && !isDevelopment; -const DEV_ADMIN_PASSWORD = (() => { - const value = getOptionalTrimmedEnv("DEV_ADMIN_PASSWORD"); - if (allowDebugAuth && !isDevelopment) { +function getLowercaseDebugEnv(envName: string, fallback: string) { + return getOptionalTrimmedEnv(envName)?.toLowerCase() ?? fallback; +} + +function getDebugPassword(envName: string, fallback: string) { + const value = getOptionalTrimmedEnv(envName); + if (requiresExplicitDebugPassword) { if (!value) { throw new Error( - "DEV_ADMIN_PASSWORD is required when E2E_DEBUG_AUTH=1 (non-development NODE_ENV)", + `${envName} is required when E2E_DEBUG_AUTH=1 (non-development NODE_ENV)`, ); } return value; } - return value || "dev-admin-password"; -})(); -export function getDebugProviderConfig( - providerId: DebugProviderId, -): DebugProviderConfig { - if (providerId === DEV_DEBUG_PROVIDER_ID) { - return { - username: DEV_DEBUG_USERNAME, - name: DEV_DEBUG_NAME, - email: DEV_DEBUG_EMAIL, - password: DEV_DEBUG_PASSWORD, - isAdmin: false, - image: "https://api.dicebear.com/9.x/shapes/svg?seed=life-ustc-dev", - }; - } + return value || fallback; +} + +function buildDebugProviderConfig({ + usernameEnv, + username: fallbackUsername, + nameEnv, + name, + emailEnv, + passwordEnv, + password, + isAdmin, + imageSeed, +}: DebugProviderDefaults): DebugProviderConfig { + const username = getLowercaseDebugEnv(usernameEnv, fallbackUsername); return { - username: DEV_ADMIN_USERNAME, - name: DEV_ADMIN_NAME, - email: DEV_ADMIN_EMAIL, - password: DEV_ADMIN_PASSWORD, - isAdmin: true, - image: "https://api.dicebear.com/9.x/shapes/svg?seed=life-ustc-dev-admin", + username, + name: getOptionalTrimmedEnv(nameEnv) ?? name, + email: getLowercaseDebugEnv(emailEnv, `${username}@debug.local`), + password: getDebugPassword(passwordEnv, password), + isAdmin, + image: `https://api.dicebear.com/9.x/shapes/svg?seed=${imageSeed}`, }; } +const DEBUG_PROVIDER_CONFIGS: Record = { + [DEV_DEBUG_PROVIDER_ID]: buildDebugProviderConfig( + DEBUG_PROVIDER_DEFAULTS[DEV_DEBUG_PROVIDER_ID], + ), + [DEV_ADMIN_PROVIDER_ID]: buildDebugProviderConfig( + DEBUG_PROVIDER_DEFAULTS[DEV_ADMIN_PROVIDER_ID], + ), +}; + +export function getDebugProviderConfig( + providerId: DebugProviderId, +): DebugProviderConfig { + return DEBUG_PROVIDER_CONFIGS[providerId]; +} + export async function ensureDebugCredentialUser(providerId: DebugProviderId) { const config = getDebugProviderConfig(providerId); const hashedPassword = await hashPassword(config.password); @@ -157,5 +188,3 @@ export async function ensureDebugCredentialUser(providerId: DebugProviderId) { }, }); } - -export { isDebugProviderId }; diff --git a/src/lib/auth/helpers.ts b/src/lib/auth/helpers.ts index 831af4d2..50ac2d6a 100644 --- a/src/lib/auth/helpers.ts +++ b/src/lib/auth/helpers.ts @@ -44,8 +44,6 @@ export async function resolveApiUserId( // General protected REST endpoints only accept issuer-bound JWT access // tokens. Opaque/no-resource tokens are reserved for the MCP transport, // where resource and scope checks happen in src/lib/mcp/auth.ts. - // Keep the legacy bare-origin audience in one helper while the - // canonical issuer remains path-based at `/api/auth`. verifyOptions: { issuer: getOAuthTokenVerificationIssuers(), audience: getOAuthRestAudienceUrls(), @@ -68,6 +66,23 @@ export async function resolveApiUserId( return session?.user?.id ?? null; } +/** + * Require an authenticated user ID from a request. + * + * Returns `{ userId }` on success, or a 401 Response on failure. + * Use this for routes that need auth but don't need suspension checks. + * + * const auth = await requireAuth(request); + * if (auth instanceof Response) return auth; + * const { userId } = auth; + */ +export async function requireAuth( + request: Request, +): Promise<{ userId: string } | Response> { + const userId = await resolveApiUserId(request); + return userId ? { userId } : unauthorized(); +} + /** * Check auth + suspension for collaborative write routes. * @@ -84,6 +99,6 @@ export async function requireWriteAuth( if (!userId) return unauthorized(); const data = await getViewerAuthDataForUserId(userId); if (!data) return unauthorized(); - if (data?.suspension) return suspensionForbidden(data.suspension.reason); + if (data.suspension) return suspensionForbidden(data.suspension.reason); return { userId }; } diff --git a/src/lib/auth/oauth-profile.ts b/src/lib/auth/oauth-profile.ts index 3f90fc35..43ebe345 100644 --- a/src/lib/auth/oauth-profile.ts +++ b/src/lib/auth/oauth-profile.ts @@ -1,18 +1,34 @@ type OAuthProfile = Record; -export const profileImage = (value: unknown): string | undefined => +type GithubProfile = { + email?: string | null; + id: string; + name?: string; + login?: string; + avatar_url?: string; +}; + +type GoogleProfile = { + email?: string; + sub: string; + name?: string; + picture?: string; + email_verified?: boolean; +}; + +const profileImage = (value: unknown): string | undefined => typeof value === "string" && value.length > 0 ? value : undefined; -export const profileName = (value: unknown): string => +const profileName = (value: unknown): string => typeof value === "string" && value.trim().length > 0 ? value.trim() : ""; -export const fallbackEmail = (provider: string, accountId: unknown): string => +const fallbackEmail = (provider: string, accountId: unknown): string => `${provider}-${String(accountId)}@users.local`; -const firstStringValue = ( - profile: OAuthProfile, - keys: Array, -): string | null => { +const profileEmail = (value: unknown): string | null => + typeof value === "string" && value.length > 0 ? value : null; + +const firstStringValue = (profile: OAuthProfile, keys: readonly string[]) => { for (const key of keys) { const value = profile[key]; if (typeof value === "string" && value.trim().length > 0) { @@ -25,29 +41,44 @@ const firstStringValue = ( return null; }; +function firstBooleanValue(profile: OAuthProfile, keys: readonly string[]) { + for (const key of keys) { + const value = profile[key]; + if (typeof value === "boolean") { + return value; + } + } + return false; +} + +function firstProfileName(profile: OAuthProfile, keys: readonly string[]) { + for (const key of keys) { + const name = profileName(profile[key]); + if (name) { + return name; + } + } + return null; +} + export function mapOidcProfileToUser(profile: OAuthProfile) { const accountId = firstStringValue(profile, ["sub", "id", "user_id"]); if (!accountId) { throw new Error("OIDC profile is missing a stable account identifier"); } - const email = - typeof profile.email === "string" && profile.email.length > 0 - ? profile.email - : null; - const emailVerified = - typeof profile.email_verified === "boolean" - ? profile.email_verified - : typeof profile.emailVerified === "boolean" - ? profile.emailVerified - : false; + const email = profileEmail(profile.email); + const emailVerified = firstBooleanValue(profile, [ + "email_verified", + "emailVerified", + ]); const displayName = - profileName( - profile.name ?? - profile.preferred_username ?? - profile.nickname ?? - profile.email, - ) || `USTC User ${accountId}`; + firstProfileName(profile, [ + "name", + "preferred_username", + "nickname", + "email", + ]) ?? `USTC User ${accountId}`; return { id: accountId, @@ -57,3 +88,28 @@ export function mapOidcProfileToUser(profile: OAuthProfile) { emailVerified: Boolean(email && emailVerified), }; } + +export function mapGithubProfileToUser(profile: GithubProfile) { + const email = profileEmail(profile.email); + return { + email: email ?? fallbackEmail("github", profile.id), + name: profileName(profile.name ?? profile.login), + image: profileImage(profile.avatar_url), + // GitHub may return unverified or hidden emails; do not mark + // fallback/local emails as verified. + emailVerified: false, + }; +} + +export function mapGoogleProfileToUser(profile: GoogleProfile) { + const email = profileEmail(profile.email); + return { + email: email ?? fallbackEmail("google", profile.sub), + name: profileName(profile.name), + image: profileImage(profile.picture), + emailVerified: + email !== null && typeof profile.email_verified === "boolean" + ? profile.email_verified + : false, + }; +} diff --git a/src/lib/auth/provider-ids.ts b/src/lib/auth/provider-ids.ts index 314070cc..f3959f22 100644 --- a/src/lib/auth/provider-ids.ts +++ b/src/lib/auth/provider-ids.ts @@ -11,22 +11,23 @@ const DEBUG_SIGN_IN_PROVIDER_IDS = [ DEV_DEBUG_PROVIDER_ID, DEV_ADMIN_PROVIDER_ID, ] as const; +const DEBUG_SIGN_IN_PROVIDER_ID_SET = new Set( + DEBUG_SIGN_IN_PROVIDER_IDS, +); const ALL_SIGN_IN_PROVIDER_IDS = [ ...PRIMARY_SIGN_IN_PROVIDER_IDS, ...DEBUG_SIGN_IN_PROVIDER_IDS, ] as const; export type DebugProviderId = (typeof DEBUG_SIGN_IN_PROVIDER_IDS)[number]; -export type AuthProviderDecision = +type AuthProviderDecision = | { kind: "none" } | { kind: "debug"; providerId: DebugProviderId } | { kind: "oidc"; providerId: typeof OIDC_PROVIDER_ID } | { kind: "social"; providerId: string }; -export function isDebugProviderId( - providerId: string, -): providerId is DebugProviderId { - return DEBUG_SIGN_IN_PROVIDER_IDS.includes(providerId as DebugProviderId); +function isDebugProviderId(providerId: string): providerId is DebugProviderId { + return DEBUG_SIGN_IN_PROVIDER_ID_SET.has(providerId); } export function resolveAuthProviderDecision( diff --git a/src/lib/mcp/auth.ts b/src/lib/mcp/auth.ts index ae13129e..ede073a6 100644 --- a/src/lib/mcp/auth.ts +++ b/src/lib/mcp/auth.ts @@ -2,9 +2,9 @@ import type { AuthInfo } from "@modelcontextprotocol/sdk/server/auth/types.js"; import { verifyAccessToken as verifyOAuthAccessToken } from "better-auth/oauth2"; import { prisma } from "@/lib/db/prisma"; import { isOAuthDebugLogging, logOAuthDebug } from "@/lib/log/oauth-debug"; +import { MCP_TOOLS_SCOPE } from "@/lib/oauth/constants"; import { hashOAuthClientSecretForDbStorage, - MCP_TOOLS_SCOPE, resourceIndicatorsMatch, } from "@/lib/oauth/utils"; import { @@ -25,12 +25,10 @@ type AuthFailure = { }; function buildBearerHeader({ - request, error, description, scopes, }: { - request: Request; error: string; description: string; scopes?: string[]; @@ -38,7 +36,7 @@ function buildBearerHeader({ const parts = [ `Bearer error="${error}"`, `error_description="${description}"`, - `resource_metadata="${getOAuthProtectedResourceMetadataUrl(request).toString()}"`, + `resource_metadata="${getOAuthProtectedResourceMetadataUrl().toString()}"`, ]; if (scopes && scopes.length > 0) { @@ -48,17 +46,12 @@ function buildBearerHeader({ return parts.join(", "); } -function buildAuthErrorResponse( - request: Request, - failure: AuthFailure, - scopes?: string[], -) { +function buildAuthErrorResponse(failure: AuthFailure, scopes?: string[]) { return new Response(JSON.stringify({ error: failure.error }), { status: failure.status, headers: { "Content-Type": "application/json", "WWW-Authenticate": buildBearerHeader({ - request, error: failure.error, description: failure.description, scopes, @@ -128,13 +121,18 @@ export async function verifyAccessToken( audience: getOAuthMcpAudienceUrls(), }, }); + const jwtClaims = jwt as { + aud?: unknown; + azp?: unknown; + exp?: unknown; + scope?: unknown; + sub?: unknown; + }; const scopeValue = - typeof (jwt as { scope?: unknown }).scope === "string" - ? (jwt as { scope: string }).scope - : ""; + typeof jwtClaims.scope === "string" ? jwtClaims.scope : ""; const scopes = scopeValue.split(" ").filter(Boolean); - const aud = (jwt as { aud?: unknown }).aud; + const aud = jwtClaims.aud; let audValue = ""; if (typeof aud === "string") { audValue = aud; @@ -148,21 +146,15 @@ export async function verifyAccessToken( return { token, - clientId: - typeof (jwt as { azp?: unknown }).azp === "string" - ? (jwt as { azp: string }).azp - : "unknown", + clientId: typeof jwtClaims.azp === "string" ? jwtClaims.azp : "unknown", scopes, expiresAt: - typeof (jwt as { exp?: unknown }).exp === "number" - ? (jwt as { exp: number }).exp + typeof jwtClaims.exp === "number" + ? jwtClaims.exp : Math.floor(Date.now() / 1000) + 60, resource: audValue ? new URL(audValue) : undefined, extra: { - userId: - typeof (jwt as { sub?: unknown }).sub === "string" - ? (jwt as { sub: string }).sub - : undefined, + userId: typeof jwtClaims.sub === "string" ? jwtClaims.sub : undefined, }, }; } catch (err) { @@ -204,7 +196,7 @@ export async function authenticateMcpRequest( const token = parseBearerToken(request); if (!token) { return { - response: buildAuthErrorResponse(request, { + response: buildAuthErrorResponse({ error: INVALID_TOKEN_ERROR, status: 401, description: "Missing bearer token", @@ -214,7 +206,7 @@ export async function authenticateMcpRequest( const authInfo = await verifyAccessToken(request, token); if ("error" in authInfo) { - return { response: buildAuthErrorResponse(request, authInfo) }; + return { response: buildAuthErrorResponse(authInfo) }; } if ( @@ -222,7 +214,7 @@ export async function authenticateMcpRequest( !resourceIndicatorsMatch(authInfo.resource, getOAuthMcpResourceUrl()) ) { return { - response: buildAuthErrorResponse(request, { + response: buildAuthErrorResponse({ error: INVALID_TOKEN_ERROR, status: 401, description: "Access token is not bound to this MCP resource", @@ -233,7 +225,6 @@ export async function authenticateMcpRequest( if (!authInfo.scopes.includes(MCP_TOOLS_SCOPE)) { return { response: buildAuthErrorResponse( - request, { error: INSUFFICIENT_SCOPE_ERROR, status: 403, diff --git a/src/lib/oauth/client-registration.ts b/src/lib/oauth/client-registration.ts index 3f6412c7..3383cbde 100644 --- a/src/lib/oauth/client-registration.ts +++ b/src/lib/oauth/client-registration.ts @@ -1,32 +1,35 @@ import { DEFAULT_OAUTH_CLIENT_SCOPES, MCP_TOOLS_SCOPE, -} from "@/lib/oauth/utils"; - -export const DEFAULT_DYNAMIC_OAUTH_CLIENT_SCOPES = [ - ...DEFAULT_OAUTH_CLIENT_SCOPES, - MCP_TOOLS_SCOPE, -]; + OAUTH_AUTHORIZATION_CODE_GRANT_TYPE, + OAUTH_OFFLINE_ACCESS_SCOPE, + OAUTH_REFRESH_TOKEN_GRANT_TYPE, +} from "@/lib/oauth/constants"; type ValidationErrorResult = { error: string }; type ScopesResult = ValidationErrorResult | { scopes: string[] }; const SUPPORTED_DYNAMIC_CLIENT_SCOPES = new Set([ - ...DEFAULT_DYNAMIC_OAUTH_CLIENT_SCOPES, - "offline_access", + ...DEFAULT_OAUTH_CLIENT_SCOPES, + MCP_TOOLS_SCOPE, + OAUTH_OFFLINE_ACCESS_SCOPE, ]); -export function resolveOAuthClientScopes(options: { - defaultScopes: string[]; - requestedScopes?: string[] | string | null; -}): ScopesResult { - const requestedScopes = - typeof options.requestedScopes === "string" - ? options.requestedScopes.split(" ").filter(Boolean) - : (options.requestedScopes ?? []); +function parseRequestedScopes(input?: string[] | string | null) { + if (typeof input === "string") { + return input.split(" ").filter(Boolean); + } + + return input ?? []; +} + +export function resolveOAuthClientScopes( + requestedScopesInput?: string[] | string | null, +): ScopesResult { + const requestedScopes = parseRequestedScopes(requestedScopesInput); if (requestedScopes.length === 0) { - return { scopes: [...options.defaultScopes] }; + return { scopes: [...DEFAULT_OAUTH_CLIENT_SCOPES] }; } const invalidScopes = requestedScopes.filter( @@ -41,3 +44,9 @@ export function resolveOAuthClientScopes(options: { return { scopes: [...new Set(requestedScopes)] }; } + +export function resolveOAuthClientGrantTypes(scopes: readonly string[]) { + return scopes.includes(OAUTH_OFFLINE_ACCESS_SCOPE) + ? [OAUTH_AUTHORIZATION_CODE_GRANT_TYPE, OAUTH_REFRESH_TOKEN_GRANT_TYPE] + : [OAUTH_AUTHORIZATION_CODE_GRANT_TYPE]; +} diff --git a/src/lib/oauth/constants.ts b/src/lib/oauth/constants.ts new file mode 100644 index 00000000..bbfee673 --- /dev/null +++ b/src/lib/oauth/constants.ts @@ -0,0 +1,44 @@ +export const OAUTH_PUBLIC_CLIENT_AUTH_METHOD = "none"; +export const OAUTH_CLIENT_SECRET_BASIC_AUTH_METHOD = "client_secret_basic"; +export const OAUTH_CLIENT_SECRET_POST_AUTH_METHOD = "client_secret_post"; +export const OAUTH_AUTHORIZATION_CODE_GRANT_TYPE = "authorization_code"; +export const OAUTH_REFRESH_TOKEN_GRANT_TYPE = "refresh_token"; +export const OAUTH_DEVICE_CODE_GRANT_TYPE = + "urn:ietf:params:oauth:grant-type:device_code"; +export const OAUTH_CODE_RESPONSE_TYPE = "code"; +export const OAUTH_DEVICE_AUTHORIZATION_ENDPOINT_PATH = + "/api/auth/oauth2/device-authorization"; +export const OAUTH_TOKEN_ENDPOINT_PATH = "/api/auth/oauth2/token"; +export const OAUTH_OPENID_SCOPE = "openid"; +export const OAUTH_PROFILE_SCOPE = "profile"; +export const OAUTH_EMAIL_SCOPE = "email"; +export const OAUTH_OFFLINE_ACCESS_SCOPE = "offline_access"; +export const MCP_TOOLS_SCOPE = "mcp:tools"; +export const DEFAULT_OAUTH_CLIENT_SCOPES = [ + OAUTH_OPENID_SCOPE, + OAUTH_PROFILE_SCOPE, +] as const; +export const OAUTH_PROVIDER_SCOPES = [ + OAUTH_OPENID_SCOPE, + OAUTH_PROFILE_SCOPE, + OAUTH_EMAIL_SCOPE, + OAUTH_OFFLINE_ACCESS_SCOPE, + MCP_TOOLS_SCOPE, +] as const; +export const SUPPORTED_OAUTH_CLIENT_AUTH_METHODS = [ + OAUTH_PUBLIC_CLIENT_AUTH_METHOD, + OAUTH_CLIENT_SECRET_BASIC_AUTH_METHOD, + OAUTH_CLIENT_SECRET_POST_AUTH_METHOD, +] as const; +const SUPPORTED_OAUTH_CLIENT_AUTH_METHOD_SET = new Set( + SUPPORTED_OAUTH_CLIENT_AUTH_METHODS, +); + +export type SupportedOAuthClientAuthMethod = + (typeof SUPPORTED_OAUTH_CLIENT_AUTH_METHODS)[number]; + +export function isSupportedOAuthClientAuthMethod( + value: string, +): value is SupportedOAuthClientAuthMethod { + return SUPPORTED_OAUTH_CLIENT_AUTH_METHOD_SET.has(value); +} diff --git a/src/lib/oauth/discovery-metadata.ts b/src/lib/oauth/discovery-metadata.ts index 7953ac83..fe931b5e 100644 --- a/src/lib/oauth/discovery-metadata.ts +++ b/src/lib/oauth/discovery-metadata.ts @@ -4,9 +4,12 @@ import { } from "@better-auth/oauth-provider"; import { NextResponse } from "next/server"; import { betterAuthInstance } from "@/auth"; +import { + OAUTH_DEVICE_AUTHORIZATION_ENDPOINT_PATH, + OAUTH_DEVICE_CODE_GRANT_TYPE, +} from "@/lib/oauth/constants"; import { asOAuthProviderMetadataAuth } from "@/lib/oauth/provider-api"; -const DEVICE_CODE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:device_code"; const DISCOVERY_CORS_HEADERS = { "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "GET, OPTIONS", @@ -46,11 +49,11 @@ function augmentDiscoveryMetadata( return { ...body, - device_authorization_endpoint: `${siteOrigin}/api/auth/oauth2/device-authorization`, + device_authorization_endpoint: `${siteOrigin}${OAUTH_DEVICE_AUTHORIZATION_ENDPOINT_PATH}`, grant_types_supported: [ ...new Set([ ...(body.grant_types_supported ?? []), - DEVICE_CODE_GRANT_TYPE, + OAUTH_DEVICE_CODE_GRANT_TYPE, ]), ], }; @@ -63,9 +66,9 @@ async function buildDiscoveryMetadataResponse( const response = await handler(request); const body = (await response.json()) as DiscoveryMetadata; - return new Response(JSON.stringify(augmentDiscoveryMetadata(request, body)), { + return createDiscoveryJsonResponse(augmentDiscoveryMetadata(request, body), { status: response.status, - headers: withDiscoveryCorsHeaders(response.headers), + headers: response.headers, }); } @@ -77,23 +80,28 @@ export async function getOpenIdMetadataResponse(request: Request) { return buildDiscoveryMetadataResponse(request, openIdConfigMetadataHandler); } -export function getDiscoveryOptionsResponse() { +function getDiscoveryOptionsResponse() { return new Response(null, { status: 204, headers: withDiscoveryCorsHeaders(), }); } -export function getDiscoveryRedirectResponse(url: URL | string, status = 307) { - const response = NextResponse.redirect(url, status); - const headers = withDiscoveryCorsHeaders(response.headers); - response.headers.forEach((_, key) => { - response.headers.delete(key); +function getDiscoveryRedirectResponse(url: URL | string, status = 307) { + return NextResponse.redirect(url, { + status, + headers: withDiscoveryCorsHeaders(), }); - headers.forEach((value, key) => { - response.headers.set(key, value); +} + +export function createDiscoveryJsonResponse( + body: unknown, + init: ResponseInit = {}, +) { + return NextResponse.json(body, { + ...init, + headers: withDiscoveryCorsHeaders(init.headers), }); - return response; } type DiscoveryRouteHandlers = { diff --git a/src/lib/oauth/discovery-routes.ts b/src/lib/oauth/discovery-routes.ts new file mode 100644 index 00000000..d701e2a0 --- /dev/null +++ b/src/lib/oauth/discovery-routes.ts @@ -0,0 +1,65 @@ +import { + getMcpServerUrl, + getOAuthAuthorizationServerMetadataUrl, + getOAuthIssuerUrl, + getOAuthOpenIdConfigurationUrl, + getOAuthProtectedResourceMetadataUrl, +} from "@/lib/mcp/urls"; +import { MCP_TOOLS_SCOPE } from "@/lib/oauth/constants"; +import { + createDiscoveryJsonResponse, + createDiscoveryMetadataRoute, + createDiscoveryRedirectRoute, + getAuthServerMetadataResponse, + getOpenIdMetadataResponse, +} from "@/lib/oauth/discovery-metadata"; + +async function getProtectedResourceMetadataResponse() { + const issuerUrl = getOAuthIssuerUrl(); + + return createDiscoveryJsonResponse({ + resource: getMcpServerUrl().toString(), + authorization_servers: [issuerUrl.toString()], + scopes_supported: [MCP_TOOLS_SCOPE], + bearer_methods_supported: ["header"], + resource_documentation: new URL("/api-docs", issuerUrl).toString(), + }); +} + +const DISCOVERY_TARGETS = { + authServerMetadata: { + type: "metadata", + getResponse: getAuthServerMetadataResponse, + }, + authServerAlias: { + type: "redirect", + resolveUrl: getOAuthAuthorizationServerMetadataUrl, + }, + openIdMetadata: { + type: "metadata", + getResponse: getOpenIdMetadataResponse, + }, + openIdAlias: { + type: "redirect", + resolveUrl: getOAuthOpenIdConfigurationUrl, + }, + protectedResourceMetadata: { + type: "metadata", + getResponse: getProtectedResourceMetadataResponse, + }, + protectedResourceAlias: { + type: "redirect", + resolveUrl: getOAuthProtectedResourceMetadataUrl, + }, +} as const; + +type DiscoveryRouteTarget = keyof typeof DISCOVERY_TARGETS; + +export function createOAuthDiscoveryRoute(target: DiscoveryRouteTarget) { + const route = DISCOVERY_TARGETS[target]; + if (route.type === "metadata") { + return createDiscoveryMetadataRoute(route.getResponse); + } + + return createDiscoveryRedirectRoute(route.resolveUrl); +} diff --git a/src/lib/oauth/logging.ts b/src/lib/oauth/logging.ts deleted file mode 100644 index cad2f5d7..00000000 --- a/src/lib/oauth/logging.ts +++ /dev/null @@ -1,91 +0,0 @@ -import { formatShanghaiTimestamp } from "@/lib/time/shanghai-format"; - -type OAuthLogLevel = "warn" | "error" | "info"; - -type OAuthLogContext = { - route: string; - event: string; - status?: number; - reason?: string; - grantType?: string | null; - registeredAuthMethod?: string | null; - requestAuthMethod?: string | null; - clientId?: string | null; - redirectUri?: string | null; - resource?: string | null; - scope?: string | string[] | null; - userId?: string | null; -}; - -export function logOAuthEvent( - level: OAuthLogLevel, - context: OAuthLogContext, - error?: unknown, -) { - const method = - level === "error" - ? console.error - : level === "warn" - ? console.warn - : console.info; - - const payload = { - timestamp: formatShanghaiTimestamp(new Date()), - environment: - process.env.VERCEL_ENV ?? process.env.NODE_ENV ?? "development", - ...sanitizeOAuthLogContext(context), - }; - const serializedError = serializeError(error); - - if (serializedError) { - method("[oauth]", payload, serializedError); - return; - } - - method("[oauth]", payload); -} - -function sanitizeOAuthLogContext(context: OAuthLogContext) { - return { - ...context, - clientId: summarizeIdentifier(context.clientId), - redirectUri: summarizeUri(context.redirectUri), - resource: summarizeUri(context.resource), - scope: summarizeScope(context.scope), - }; -} - -function summarizeIdentifier(value?: string | null) { - if (!value) return null; - if (value.length <= 8) return `${value.slice(0, 2)}***`; - return `${value.slice(0, 4)}***${value.slice(-4)}`; -} - -function summarizeUri(value?: string | null) { - if (!value) return null; - - try { - const url = new URL(value); - return `${url.origin}${url.pathname}`; - } catch { - return "[invalid-uri]"; - } -} - -function summarizeScope(scope?: string | string[] | null) { - if (!scope) return null; - const values = - typeof scope === "string" ? scope.split(" ").filter(Boolean) : scope; - return values.slice(0, 8); -} - -function serializeError(error: unknown) { - if (!error) return undefined; - if (error instanceof Error) { - return { - name: error.name, - message: error.message, - }; - } - return { error }; -} diff --git a/src/lib/oauth/loopback-redirect.ts b/src/lib/oauth/loopback-redirect.ts index a5f87272..0c436e22 100644 --- a/src/lib/oauth/loopback-redirect.ts +++ b/src/lib/oauth/loopback-redirect.ts @@ -1,14 +1,33 @@ -const IPV4_LOOPBACK_HOSTS = new Set(["127.0.0.1", "localhost"]); -const IPV6_LOOPBACK_HOSTS = new Set(["[::1]", "::1"]); +const LOOPBACK_HOST_FAMILY_BY_HOSTNAME = new Map([ + ["127.0.0.1", "ipv4"], + ["localhost", "ipv4"], + ["[::1]", "ipv6"], + ["::1", "ipv6"], +]); +const STRICT_REDIRECT_URL_PARTS = [ + "protocol", + "port", + "pathname", + "search", + "hash", +] as const; function getLoopbackHostFamily(hostname: string): "ipv4" | "ipv6" | null { - if (IPV4_LOOPBACK_HOSTS.has(hostname)) { - return "ipv4"; - } - if (IPV6_LOOPBACK_HOSTS.has(hostname)) { - return "ipv6"; + return LOOPBACK_HOST_FAMILY_BY_HOSTNAME.get(hostname) ?? null; +} + +function hasSameRedirectTarget(registeredUrl: URL, requestedUrl: URL) { + return STRICT_REDIRECT_URL_PARTS.every( + (part) => registeredUrl[part] === requestedUrl[part], + ); +} + +function parseUrlOrNull(value: string): URL | null { + try { + return new URL(value); + } catch { + return null; } - return null; } /** @@ -20,10 +39,8 @@ export function resolveEquivalentLoopbackRedirectUri( registeredRedirectUris: string[], requestedRedirectUri: string, ): string | null { - let requestedUrl: URL; - try { - requestedUrl = new URL(requestedRedirectUri); - } catch { + const requestedUrl = parseUrlOrNull(requestedRedirectUri); + if (!requestedUrl) { return null; } @@ -33,29 +50,15 @@ export function resolveEquivalentLoopbackRedirectUri( } for (const registeredRedirectUri of registeredRedirectUris) { - let registeredUrl: URL; - try { - registeredUrl = new URL(registeredRedirectUri); - } catch { + const registeredUrl = parseUrlOrNull(registeredRedirectUri); + if (!registeredUrl) { continue; } if (getLoopbackHostFamily(registeredUrl.hostname) !== requestedFamily) { continue; } - if (registeredUrl.protocol !== requestedUrl.protocol) { - continue; - } - if (registeredUrl.port !== requestedUrl.port) { - continue; - } - if (registeredUrl.pathname !== requestedUrl.pathname) { - continue; - } - if (registeredUrl.search !== requestedUrl.search) { - continue; - } - if (registeredUrl.hash !== requestedUrl.hash) { + if (!hasSameRedirectTarget(registeredUrl, requestedUrl)) { continue; } diff --git a/src/lib/oauth/provider-api.ts b/src/lib/oauth/provider-api.ts index 857e1afd..c9c8a89f 100644 --- a/src/lib/oauth/provider-api.ts +++ b/src/lib/oauth/provider-api.ts @@ -1,4 +1,4 @@ -export type AdminCreateOAuthClientInput = { +type AdminCreateOAuthClientInput = { headers: Headers; body: { client_name: string; @@ -14,17 +14,17 @@ export type AdminCreateOAuthClientInput = { }; }; -export type AdminCreateOAuthClientResult = { +type AdminCreateOAuthClientResult = { client_id: string; client_secret?: string | null; }; -export type OAuthClientPublicResult = { +type OAuthClientPublicResult = { client_id: string; client_name?: string | null; }; -export type OAuthProviderApi = { +type OAuthProviderApi = { adminCreateOAuthClient( input: AdminCreateOAuthClientInput, ): Promise; @@ -34,42 +34,34 @@ export type OAuthProviderApi = { }): Promise; }; -export type OAuthProviderMetadataAuth = { +type OAuthProviderMetadataAuth = { api: { getOAuthServerConfig: (...args: unknown[]) => unknown; getOpenIdConfig: (...args: unknown[]) => unknown; }; }; -export type GenericOAuthApi = { +type GenericOAuthApi = { signInWithOAuth2(input: { - body: { - providerId: string; - callbackURL: string; - }; + body: { providerId: string; callbackURL: string }; }): Promise; }; -function asRecord( - value: unknown, - errorMessage: string, -): Record { - if (value && typeof value === "object") { +function asRecord(value: unknown, message: string): Record { + if (value && typeof value === "object") return value as Record; - } - throw new Error(errorMessage); + throw new Error(message); } function requireMethod( target: Record, - methodName: string, - errorMessage: string, + label: string, + method: string, ): (...args: TArgs) => TReturn { - const method = target[methodName]; - if (typeof method !== "function") { - throw new Error(errorMessage); - } - return method.bind(target) as (...args: TArgs) => TReturn; + const fn = target[method]; + if (typeof fn !== "function") + throw new Error(`${label} is unavailable: missing ${method}()`); + return fn.bind(target) as (...args: TArgs) => TReturn; } export function asOAuthProviderApi(api: unknown): OAuthProviderApi { @@ -77,17 +69,16 @@ export function asOAuthProviderApi(api: unknown): OAuthProviderApi { api, "Better Auth OAuth provider API is unavailable: expected an object API surface", ); - return { adminCreateOAuthClient: requireMethod( record, + "Better Auth OAuth provider API", "adminCreateOAuthClient", - "Better Auth OAuth provider API is unavailable: missing adminCreateOAuthClient()", ), getOAuthClientPublic: requireMethod( record, + "Better Auth OAuth provider API", "getOAuthClientPublic", - "Better Auth OAuth provider API is unavailable: missing getOAuthClientPublic()", ), }; } @@ -103,18 +94,17 @@ export function asOAuthProviderMetadataAuth( authRecord.api, "Better Auth OAuth metadata API is unavailable: missing auth.api object", ); - return { api: { getOAuthServerConfig: requireMethod( apiRecord, + "Better Auth OAuth metadata API", "getOAuthServerConfig", - "Better Auth OAuth metadata API is unavailable: missing getOAuthServerConfig()", ), getOpenIdConfig: requireMethod( apiRecord, + "Better Auth OAuth metadata API", "getOpenIdConfig", - "Better Auth OAuth metadata API is unavailable: missing getOpenIdConfig()", ), }, }; @@ -125,12 +115,11 @@ export function asGenericOAuthApi(api: unknown): GenericOAuthApi { api, "Better Auth generic OAuth API is unavailable: expected an object API surface", ); - return { signInWithOAuth2: requireMethod( record, + "Better Auth generic OAuth API", "signInWithOAuth2", - "Better Auth generic OAuth API is unavailable: missing signInWithOAuth2()", ), }; } diff --git a/src/lib/oauth/redirect.ts b/src/lib/oauth/redirect.ts deleted file mode 100644 index 840cf50d..00000000 --- a/src/lib/oauth/redirect.ts +++ /dev/null @@ -1,21 +0,0 @@ -export function buildOAuthErrorRedirectUri({ - redirectUri, - error, - state, - errorDescription, -}: { - redirectUri: string; - error: string; - state?: string; - errorDescription?: string; -}): string { - const url = new URL(redirectUri); - url.searchParams.set("error", error); - if (state) { - url.searchParams.set("state", state); - } - if (errorDescription) { - url.searchParams.set("error_description", errorDescription); - } - return url.toString(); -} diff --git a/src/lib/oauth/utils.ts b/src/lib/oauth/utils.ts index 5637f189..a7e59fbe 100644 --- a/src/lib/oauth/utils.ts +++ b/src/lib/oauth/utils.ts @@ -1,15 +1,5 @@ import { createHash } from "node:crypto"; -export const OAUTH_PUBLIC_CLIENT_AUTH_METHOD = "none"; -export const OAUTH_CLIENT_SECRET_BASIC_AUTH_METHOD = "client_secret_basic"; -export const OAUTH_CLIENT_SECRET_POST_AUTH_METHOD = "client_secret_post"; -export const MCP_TOOLS_SCOPE = "mcp:tools"; -export const DEFAULT_OAUTH_CLIENT_SCOPES = ["openid", "profile"] as const; -export type SupportedOAuthClientAuthMethod = - | typeof OAUTH_PUBLIC_CLIENT_AUTH_METHOD - | typeof OAUTH_CLIENT_SECRET_BASIC_AUTH_METHOD - | typeof OAUTH_CLIENT_SECRET_POST_AUTH_METHOD; - /** * Matches `@better-auth/oauth-provider` default client secret storage when the JWT * plugin is enabled: SHA-256 over UTF-8, base64url without padding (`storeClientSecret: "hashed"`). @@ -21,10 +11,6 @@ export function hashOAuthClientSecretForDbStorage(plainSecret: string): string { export function normalizeResourceIndicator(value: string | URL): string { const parsed = new URL(value); - if (parsed.hash) { - throw new TypeError("Resource indicators must not include fragments"); - } - const protocol = parsed.protocol.toLowerCase(); const hostname = parsed.hostname.toLowerCase(); const port = diff --git a/src/lib/site-url.ts b/src/lib/site-url.ts index 85447479..0dba9ba2 100644 --- a/src/lib/site-url.ts +++ b/src/lib/site-url.ts @@ -18,27 +18,28 @@ function normalizeVercelHost(value: string, envName: string): string { return normalizeAbsoluteOrigin(`https://${trimmed}`, envName); } +function getAbsoluteOriginEnv(envName: string) { + const value = getOptionalTrimmedEnv(envName); + if (!value) return undefined; + return normalizeAbsoluteOrigin(value, envName); +} + +function getVercelOriginEnv(envName: string) { + const value = getOptionalTrimmedEnv(envName); + if (!value) return undefined; + return normalizeVercelHost(value, envName); +} + /** * Public origin of the current deployment. Prefer explicit configuration, then * fall back to Vercel runtime metadata, then localhost for local development. */ export function getPublicOrigin(): string { - const configured = getOptionalTrimmedEnv("APP_PUBLIC_ORIGIN"); - if (configured) { - return normalizeAbsoluteOrigin(configured, "APP_PUBLIC_ORIGIN"); - } - - const legacyBetterAuthUrl = getOptionalTrimmedEnv("BETTER_AUTH_URL"); - if (legacyBetterAuthUrl) { - return normalizeAbsoluteOrigin(legacyBetterAuthUrl, "BETTER_AUTH_URL"); - } - - const vercelUrl = getOptionalTrimmedEnv("VERCEL_URL"); - if (vercelUrl) { - return normalizeVercelHost(vercelUrl, "VERCEL_URL"); - } - - return DEFAULT_LOCAL_ORIGIN; + return ( + getAbsoluteOriginEnv("APP_PUBLIC_ORIGIN") ?? + getVercelOriginEnv("VERCEL_URL") ?? + DEFAULT_LOCAL_ORIGIN + ); } /** @@ -46,26 +47,13 @@ export function getPublicOrigin(): string { * production host even in previews, which is a useful fallback when unset. */ export function getCanonicalOrigin(): string { - const configured = getOptionalTrimmedEnv("APP_CANONICAL_ORIGIN"); - if (configured) { - return normalizeAbsoluteOrigin(configured, "APP_CANONICAL_ORIGIN"); - } - - const vercelProductionUrl = getOptionalTrimmedEnv( - "VERCEL_PROJECT_PRODUCTION_URL", + return ( + getAbsoluteOriginEnv("APP_CANONICAL_ORIGIN") ?? + getVercelOriginEnv("VERCEL_PROJECT_PRODUCTION_URL") ?? + getPublicOrigin() ); - if (vercelProductionUrl) { - return normalizeVercelHost( - vercelProductionUrl, - "VERCEL_PROJECT_PRODUCTION_URL", - ); - } - - return getPublicOrigin(); } export function getBetterAuthBaseUrl(): string { - return new URL("/api/auth", `${getPublicOrigin()}/`) - .toString() - .replace(/\/$/, ""); + return `${getPublicOrigin()}/api/auth`; } diff --git a/tests/e2e/src/app/admin/oauth/test.ts b/tests/e2e/src/app/admin/oauth/test.ts index 32c595a0..b9afb001 100644 --- a/tests/e2e/src/app/admin/oauth/test.ts +++ b/tests/e2e/src/app/admin/oauth/test.ts @@ -53,7 +53,6 @@ test("/admin/oauth 管理员可创建并删除客户端", async ({ page }, testI await expect( page.getByRole("heading", { name: /OAuth 客户端管理|OAuth Clients/i }), ).toBeVisible(); - await page.waitForLoadState("networkidle"); const createBtn = page .getByRole("button", { name: /创建客户端|Create Client/i }) @@ -102,9 +101,15 @@ test("/admin/oauth 管理员可创建并删除客户端", async ({ page }, testI ).toBeVisible(); await captureStepScreenshot(page, testInfo, "admin-oauth-created"); - const clientCard = page.locator('[class*="rounded-2xl"]').filter({ - has: page.getByText(clientName, { exact: true }), - }); + await page.getByRole("button", { name: /完成|Done/i }).click(); + + const clientCard = page + .locator("article") + .filter({ + has: page.getByText(clientName, { exact: true }), + }) + .first(); + await expect(clientCard).toBeVisible(); await clientCard.getByRole("button", { name: /删除|Delete/i }).click(); await expect(page.getByText(clientName)).toHaveCount(0, { @@ -125,7 +130,6 @@ test("/admin/oauth created client shows all required fields", async ({ try { // Force fresh sign-in (ui:true) to avoid stale auth-cache from previous test await signInAsDevAdmin(page, "/admin/oauth", "/admin/oauth", { ui: true }); - await page.waitForLoadState("networkidle"); const createBtn = page .getByRole("button", { name: /创建客户端|Create Client/i }) @@ -157,7 +161,7 @@ test("/admin/oauth created client shows all required fields", async ({ .click(); const clientCard = page - .locator('[class*="rounded"]') + .locator("article, [class*='rounded']") .filter({ has: page.getByText(clientName, { exact: true }), }) diff --git a/tests/e2e/src/app/api/oauth/register/test.ts b/tests/e2e/src/app/api/oauth/register/test.ts index d6df1614..3f08be9d 100644 --- a/tests/e2e/src/app/api/oauth/register/test.ts +++ b/tests/e2e/src/app/api/oauth/register/test.ts @@ -19,6 +19,15 @@ */ import { createHash } from "node:crypto"; import { expect, test } from "@playwright/test"; +import { + MCP_TOOLS_SCOPE, + OAUTH_AUTHORIZATION_CODE_GRANT_TYPE, + OAUTH_CODE_RESPONSE_TYPE, + OAUTH_EMAIL_SCOPE, + OAUTH_OPENID_SCOPE, + OAUTH_PROFILE_SCOPE, + OAUTH_PUBLIC_CLIENT_AUTH_METHOD, +} from "@/lib/oauth/constants"; import { signInAsDebugUser } from "../../../../../utils/auth"; import { PLAYWRIGHT_BASE_URL } from "../../../../../utils/e2e-db"; @@ -32,6 +41,12 @@ const CODE_VERIFIER = "oauth-provider-e2e-verifier-0123456789012345678901234567890123456789"; const LOOPBACK_REDIRECT_URI = "http://127.0.0.1:61000/callback"; const LOOPBACK_LOCALHOST_REDIRECT_URI = "http://localhost:61000/callback"; +const DCR_CLIENT_SCOPE = [ + OAUTH_OPENID_SCOPE, + OAUTH_PROFILE_SCOPE, + OAUTH_EMAIL_SCOPE, + MCP_TOOLS_SCOPE, +].join(" "); test.describe("OAuth provider", () => { test("canonical well-known endpoints are exposed and legacy aliases redirect", async ({ @@ -144,10 +159,10 @@ test.describe("OAuth provider", () => { data: { client_name: `e2e-public-${Date.now()}`, redirect_uris: [REDIRECT_URI], - token_endpoint_auth_method: "none", - grant_types: ["authorization_code"], - response_types: ["code"], - scope: "openid profile email mcp:tools", + token_endpoint_auth_method: OAUTH_PUBLIC_CLIENT_AUTH_METHOD, + grant_types: [OAUTH_AUTHORIZATION_CODE_GRANT_TYPE], + response_types: [OAUTH_CODE_RESPONSE_TYPE], + scope: DCR_CLIENT_SCOPE, }, }, ); @@ -170,10 +185,10 @@ test.describe("OAuth provider", () => { "/api/auth/oauth2/authorize", { params: { - response_type: "code", + response_type: OAUTH_CODE_RESPONSE_TYPE, client_id: clientId, redirect_uri: REDIRECT_URI, - scope: "openid profile email mcp:tools", + scope: DCR_CLIENT_SCOPE, state: "e2e-state", prompt: "consent", code_challenge: generateCodeChallenge(CODE_VERIFIER), @@ -214,7 +229,7 @@ test.describe("OAuth provider", () => { // Exchange code for token. const tokenResponse = await request.post("/api/auth/oauth2/token", { form: { - grant_type: "authorization_code", + grant_type: OAUTH_AUTHORIZATION_CODE_GRANT_TYPE, client_id: clientId, code, code_verifier: CODE_VERIFIER, @@ -247,10 +262,10 @@ test.describe("OAuth provider", () => { data: { client_name: `e2e-loopback-${Date.now()}`, redirect_uris: [LOOPBACK_REDIRECT_URI], - token_endpoint_auth_method: "none", - grant_types: ["authorization_code"], - response_types: ["code"], - scope: "openid profile email mcp:tools", + token_endpoint_auth_method: OAUTH_PUBLIC_CLIENT_AUTH_METHOD, + grant_types: [OAUTH_AUTHORIZATION_CODE_GRANT_TYPE], + response_types: [OAUTH_CODE_RESPONSE_TYPE], + scope: DCR_CLIENT_SCOPE, type: "native", }, }, @@ -271,10 +286,10 @@ test.describe("OAuth provider", () => { "/api/auth/oauth2/authorize", { params: { - response_type: "code", + response_type: OAUTH_CODE_RESPONSE_TYPE, client_id: clientId, redirect_uri: LOOPBACK_LOCALHOST_REDIRECT_URI, - scope: "openid profile email mcp:tools", + scope: DCR_CLIENT_SCOPE, state: "e2e-loopback-state", prompt: "consent", code_challenge: generateCodeChallenge(CODE_VERIFIER), diff --git a/tests/e2e/src/app/signin/test.ts b/tests/e2e/src/app/signin/test.ts index e41b40d4..9b3a07e7 100644 --- a/tests/e2e/src/app/signin/test.ts +++ b/tests/e2e/src/app/signin/test.ts @@ -15,12 +15,44 @@ * - Already authenticated user navigating to /signin redirects away * - jwId is NOT displayed */ -import { expect, test } from "@playwright/test"; -import { signInAsDebugUser } from "../../../utils/auth"; +import { expect, type Page, test } from "@playwright/test"; +import { signInAsDebugUser, signInAsDevAdmin } from "../../../utils/auth"; import { gotoAndWaitForReady } from "../../../utils/page-ready"; import { captureStepScreenshot } from "../../../utils/screenshot"; import { assertPageContract } from "../_shared/page-contract"; +async function expectSignedOutAfterMenuClick(page: Page) { + await page.locator("#app-user-menu").getByRole("button").click(); + await page.getByRole("menuitem", { name: /登出|Sign Out/i }).click(); + + await expect(page).toHaveURL(/\/(?:\?.*)?$/); + + const readSessionState = async (path = "/api/auth/get-session") => { + const response = await page.request.get(path); + if (!response.ok()) { + return `status-${response.status()}`; + } + const session = (await response.json()) as { + user?: { id?: string } | null; + } | null; + return session?.user?.id ? "signed-in" : "signed-out"; + }; + + await expect + .poll( + () => readSessionState("/api/auth/get-session?disableCookieCache=true"), + { timeout: 10_000 }, + ) + .toBe("signed-out"); + await expect.poll(() => readSessionState()).toBe("signed-out"); + + await page.reload({ waitUntil: "domcontentloaded" }); + await expect(page.locator("#app-user-menu")).toHaveCount(0); + await expect( + page.getByRole("link", { name: /^(登录|Sign in)$/i }), + ).toBeVisible(); +} + test("/signin contract", async ({ page }, testInfo) => { await assertPageContract(page, { routePath: "/signin", testInfo }); }); @@ -69,6 +101,22 @@ test("/signin 调试用户按钮可登录", async ({ page }, testInfo) => { await captureStepScreenshot(page, testInfo, "signin/after-login"); }); +test("/signin 调试用户可登出", async ({ page }, testInfo) => { + await signInAsDebugUser(page, "/", "/", { ui: true }); + + await expectSignedOutAfterMenuClick(page); + + await captureStepScreenshot(page, testInfo, "signin/after-sign-out"); +}); + +test("/signin 调试管理员可登出", async ({ page }, testInfo) => { + await signInAsDevAdmin(page, "/", "/", { ui: true }); + + await expectSignedOutAfterMenuClick(page); + + await captureStepScreenshot(page, testInfo, "signin/admin-after-sign-out"); +}); + test("/signin post-login redirects to callbackUrl", async ({ page, }, testInfo) => { diff --git a/tests/e2e/src/app/welcome/test.ts b/tests/e2e/src/app/welcome/test.ts index 74d8a59f..22a66b95 100644 --- a/tests/e2e/src/app/welcome/test.ts +++ b/tests/e2e/src/app/welcome/test.ts @@ -89,7 +89,7 @@ test("/welcome displays required fields", async ({ page }, testInfo) => { // Seed semester name should be selectable await semesterSelector.click(); await expect( - page.getByRole("option", { name: DEV_SEED.semesterNameCn }), + page.getByRole("option", { name: DEV_SEED.semesterNameCn }).first(), ).toBeVisible({ timeout: 5_000 }); // Close dialog await page.keyboard.press("Escape"); diff --git a/tests/e2e/utils/auth.ts b/tests/e2e/utils/auth.ts index d14dbc62..f031c1c7 100644 --- a/tests/e2e/utils/auth.ts +++ b/tests/e2e/utils/auth.ts @@ -69,6 +69,7 @@ async function applyCachedSession( await page.context().addCookies(storageState.cookies); await gotoAndWaitForReady(page, expectedPath); await completeWelcomeProfileIfNeeded(page, role, expectedPath); + await expectPagePath(page, expectedPath); await expectAuthenticatedSession(page, { isAdmin: role === "admin" }); return true; } catch { diff --git a/tests/e2e/utils/e2e-db/oauth.ts b/tests/e2e/utils/e2e-db/oauth.ts index 6198e4fd..00989384 100644 --- a/tests/e2e/utils/e2e-db/oauth.ts +++ b/tests/e2e/utils/e2e-db/oauth.ts @@ -1,4 +1,13 @@ import { prisma } from "@/lib/db/prisma"; +import { + DEFAULT_OAUTH_CLIENT_SCOPES, + OAUTH_AUTHORIZATION_CODE_GRANT_TYPE, + OAUTH_CLIENT_SECRET_BASIC_AUTH_METHOD, + OAUTH_CODE_RESPONSE_TYPE, + OAUTH_PUBLIC_CLIENT_AUTH_METHOD, + OAUTH_REFRESH_TOKEN_GRANT_TYPE, + type SupportedOAuthClientAuthMethod, +} from "@/lib/oauth/constants"; import { generateToken, PLAYWRIGHT_BASE_URL } from "./core"; export async function createOAuthClientFixture( @@ -9,17 +18,14 @@ export async function createOAuthClientFixture( grantTypes?: string[]; clientId?: string; clientSecret?: string; - tokenEndpointAuthMethod?: - | "client_secret_basic" - | "client_secret_post" - | "none"; + tokenEndpointAuthMethod?: SupportedOAuthClientAuthMethod; } = {}, ) { const clientId = options.clientId ?? generateToken(16); const tokenEndpointAuthMethod = - options.tokenEndpointAuthMethod ?? "client_secret_basic"; + options.tokenEndpointAuthMethod ?? OAUTH_CLIENT_SECRET_BASIC_AUTH_METHOD; const clientSecret = - tokenEndpointAuthMethod === "none" + tokenEndpointAuthMethod === OAUTH_PUBLIC_CLIENT_AUTH_METHOD ? null : (options.clientSecret ?? generateToken(24)); const publicClientStoredSecret = generateToken(24); @@ -28,10 +34,10 @@ export async function createOAuthClientFixture( ]; const grantTypes = options.grantTypes ?? - (tokenEndpointAuthMethod === "none" - ? ["authorization_code"] - : ["authorization_code", "refresh_token"]); - const scopes = options.scopes ?? ["openid", "profile"]; + (tokenEndpointAuthMethod === OAUTH_PUBLIC_CLIENT_AUTH_METHOD + ? [OAUTH_AUTHORIZATION_CODE_GRANT_TYPE] + : [OAUTH_AUTHORIZATION_CODE_GRANT_TYPE, OAUTH_REFRESH_TOKEN_GRANT_TYPE]); + const scopes = options.scopes ?? [...DEFAULT_OAUTH_CLIENT_SCOPES]; const name = options.name ?? `e2e-oauth-${Date.now()}`; const client = await prisma.oAuthClient.create({ @@ -39,16 +45,19 @@ export async function createOAuthClientFixture( name, clientId, clientSecret: - tokenEndpointAuthMethod === "none" + tokenEndpointAuthMethod === OAUTH_PUBLIC_CLIENT_AUTH_METHOD ? publicClientStoredSecret : clientSecret, redirectUris, - type: tokenEndpointAuthMethod === "none" ? "public" : "web", + type: + tokenEndpointAuthMethod === OAUTH_PUBLIC_CLIENT_AUTH_METHOD + ? "public" + : "web", tokenEndpointAuthMethod, disabled: false, scopes, grantTypes, - responseTypes: ["code"], + responseTypes: [OAUTH_CODE_RESPONSE_TYPE], requirePKCE: true, metadata: { source: "e2e_fixture" }, }, @@ -65,7 +74,7 @@ export async function createOAuthClientFixture( return { ...client, tokenEndpointAuthMethod: - client.tokenEndpointAuthMethod ?? "client_secret_basic", + client.tokenEndpointAuthMethod ?? OAUTH_CLIENT_SECRET_BASIC_AUTH_METHOD, clientSecret, }; } diff --git a/tests/unit/auth-config.test.ts b/tests/unit/auth-config.test.ts new file mode 100644 index 00000000..5ae8eabd --- /dev/null +++ b/tests/unit/auth-config.test.ts @@ -0,0 +1,44 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; + +describe("auth config", () => { + afterEach(() => { + vi.resetModules(); + vi.unstubAllEnvs(); + }); + + it("allows debug auth in development without enabling E2E mode", async () => { + vi.stubEnv("NODE_ENV", "development"); + vi.stubEnv("E2E_DEBUG_AUTH", ""); + + const { allowDebugAuth, allowE2EDebugAuth, isDevelopment } = await import( + "@/lib/auth/auth-config" + ); + + expect(isDevelopment).toBe(true); + expect(allowE2EDebugAuth).toBe(false); + expect(allowDebugAuth).toBe(true); + }); + + it("allows E2E debug auth outside development when explicitly enabled", async () => { + vi.stubEnv("NODE_ENV", "test"); + vi.stubEnv("E2E_DEBUG_AUTH", "1"); + + const { allowDebugAuth, allowE2EDebugAuth, isDevelopment } = await import( + "@/lib/auth/auth-config" + ); + + expect(isDevelopment).toBe(false); + expect(allowE2EDebugAuth).toBe(true); + expect(allowDebugAuth).toBe(true); + }); + + it("rejects E2E debug auth on Vercel hosting", async () => { + vi.stubEnv("NODE_ENV", "production"); + vi.stubEnv("E2E_DEBUG_AUTH", "1"); + vi.stubEnv("VERCEL", "1"); + + await expect(import("@/lib/auth/auth-config")).rejects.toThrow( + "E2E_DEBUG_AUTH must not be set on Vercel/production hosting", + ); + }); +}); diff --git a/tests/unit/auth-helpers.test.ts b/tests/unit/auth-helpers.test.ts index 5d58a1b1..70e16b26 100644 --- a/tests/unit/auth-helpers.test.ts +++ b/tests/unit/auth-helpers.test.ts @@ -22,14 +22,8 @@ vi.mock("@/lib/auth/viewer-context", () => ({ vi.mock("@/lib/mcp/urls", () => ({ getJwksUrlForOAuthVerification: () => "https://life.example/api/auth/jwks", - getOAuthRestAudienceUrls: () => [ - "https://life.example", - "https://life.example/api/auth", - ], - getOAuthTokenVerificationIssuers: () => [ - "https://life.example/api/auth", - "https://life.example", - ], + getOAuthRestAudienceUrls: () => ["https://life.example/api/auth"], + getOAuthTokenVerificationIssuers: () => ["https://life.example/api/auth"], })); describe("auth helpers", () => { @@ -71,8 +65,8 @@ describe("auth helpers", () => { expect.objectContaining({ jwksUrl: "https://life.example/api/auth/jwks", verifyOptions: { - issuer: ["https://life.example/api/auth", "https://life.example"], - audience: ["https://life.example", "https://life.example/api/auth"], + issuer: ["https://life.example/api/auth"], + audience: ["https://life.example/api/auth"], }, }), ); diff --git a/tests/unit/auth-origins.test.ts b/tests/unit/auth-origins.test.ts index a8a39b61..1305036d 100644 --- a/tests/unit/auth-origins.test.ts +++ b/tests/unit/auth-origins.test.ts @@ -2,9 +2,6 @@ import { afterEach, describe, expect, it, vi } from "vitest"; import { getAuthAllowedHosts, getAuthTrustedOrigins, - getOAuthProxyCurrentUrl, - getOAuthProxyProductionUrl, - getOAuthProxySecret, isTrustedAuthOrigin, } from "@/lib/auth/auth-origins"; @@ -26,22 +23,6 @@ describe("auth origin helpers", () => { ]); }); - it("uses canonical origin as the OAuth proxy production URL", () => { - vi.stubEnv("APP_PUBLIC_ORIGIN", "https://preview-123.vercel.app"); - vi.stubEnv("APP_CANONICAL_ORIGIN", "https://life-ustc.tiankaima.dev"); - - expect(getOAuthProxyProductionUrl()).toBe( - "https://life-ustc.tiankaima.dev", - ); - }); - - it("uses the current public origin as the OAuth proxy current URL", () => { - vi.stubEnv("APP_PUBLIC_ORIGIN", "https://preview-123.vercel.app"); - vi.stubEnv("APP_CANONICAL_ORIGIN", "https://life-ustc.tiankaima.dev"); - - expect(getOAuthProxyCurrentUrl()).toBe("https://preview-123.vercel.app"); - }); - it("returns Better Auth allowed hosts for dynamic base URL resolution", () => { vi.stubEnv("APP_PUBLIC_ORIGIN", "https://preview-123.vercel.app"); vi.stubEnv("APP_CANONICAL_ORIGIN", "https://life-ustc.tiankaima.dev"); @@ -55,14 +36,16 @@ describe("auth origin helpers", () => { ]); }); - it("returns the configured OAuth proxy secret when present", () => { - vi.stubEnv("OAUTH_PROXY_SECRET", "shared-proxy-secret"); - expect(getOAuthProxySecret()).toBe("shared-proxy-secret"); - }); + it("deduplicates matching public and canonical origins", () => { + vi.stubEnv("APP_PUBLIC_ORIGIN", "https://life-ustc.tiankaima.dev"); + vi.stubEnv("APP_CANONICAL_ORIGIN", "https://life-ustc.tiankaima.dev"); - it("ignores blank OAuth proxy secret values", () => { - vi.stubEnv("OAUTH_PROXY_SECRET", " "); - expect(getOAuthProxySecret()).toBeUndefined(); + expect(getAuthTrustedOrigins()).toEqual([ + "https://life-ustc.tiankaima.dev", + "http://localhost:3000", + "http://127.0.0.1:3000", + "https://*.vercel.app", + ]); }); }); diff --git a/tests/unit/auth-provider-routing.test.ts b/tests/unit/auth-provider-routing.test.ts index 2929fb35..61cb2a23 100644 --- a/tests/unit/auth-provider-routing.test.ts +++ b/tests/unit/auth-provider-routing.test.ts @@ -1,5 +1,6 @@ import { describe, expect, it } from "vitest"; import { + buildCurrentPathCallbackUrl, buildSignInPageUrl, buildSignInRedirectUrl, resolveAuthRedirectTarget, @@ -25,6 +26,16 @@ describe("auth provider routing", () => { ); }); + it("builds current-page callback URLs from path and query", () => { + expect( + buildCurrentPathCallbackUrl( + "/sections/123", + new URLSearchParams({ tab: "homeworks", comment: "new" }), + ), + ).toBe("/sections/123?tab=homeworks&comment=new"); + expect(buildCurrentPathCallbackUrl("/courses/456")).toBe("/courses/456"); + }); + it("builds sign-in redirects from the resolved destination", () => { expect( buildSignInRedirectUrl( diff --git a/tests/unit/debug-auth.test.ts b/tests/unit/debug-auth.test.ts new file mode 100644 index 00000000..58d24103 --- /dev/null +++ b/tests/unit/debug-auth.test.ts @@ -0,0 +1,65 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; +import { + DEV_ADMIN_PROVIDER_ID, + DEV_DEBUG_PROVIDER_ID, +} from "@/lib/auth/provider-ids"; + +vi.mock("better-auth/crypto", () => ({ + hashPassword: vi.fn(), +})); + +vi.mock("@/lib/db/prisma", () => ({ + prisma: {}, +})); + +describe("debug auth config", () => { + afterEach(() => { + vi.resetModules(); + vi.unstubAllEnvs(); + }); + + it("builds default debug provider configs", async () => { + const { getDebugProviderConfig } = await import("@/lib/auth/debug-auth"); + + expect(getDebugProviderConfig(DEV_DEBUG_PROVIDER_ID)).toEqual({ + username: "dev-user", + name: "Dev User", + email: "dev-user@debug.local", + password: "dev-debug-password", + isAdmin: false, + image: "https://api.dicebear.com/9.x/shapes/svg?seed=life-ustc-dev-user", + }); + expect(getDebugProviderConfig(DEV_ADMIN_PROVIDER_ID)).toMatchObject({ + username: "dev-admin", + name: "Dev Admin User", + email: "dev-admin@debug.local", + password: "dev-admin-password", + isAdmin: true, + }); + }); + + it("trims and lowercases environment overrides", async () => { + vi.stubEnv("DEV_DEBUG_USERNAME", " Custom-User "); + vi.stubEnv("DEV_DEBUG_NAME", " Custom User "); + vi.stubEnv("DEV_DEBUG_EMAIL", " USER@Example.TEST "); + vi.stubEnv("DEV_DEBUG_PASSWORD", " custom-password "); + + const { getDebugProviderConfig } = await import("@/lib/auth/debug-auth"); + + expect(getDebugProviderConfig(DEV_DEBUG_PROVIDER_ID)).toMatchObject({ + username: "custom-user", + name: "Custom User", + email: "user@example.test", + password: "custom-password", + }); + }); + + it("requires explicit debug passwords for non-development E2E auth", async () => { + vi.stubEnv("NODE_ENV", "test"); + vi.stubEnv("E2E_DEBUG_AUTH", "1"); + + await expect(import("@/lib/auth/debug-auth")).rejects.toThrow( + "DEV_DEBUG_PASSWORD is required when E2E_DEBUG_AUTH=1 (non-development NODE_ENV)", + ); + }); +}); diff --git a/tests/unit/env.test.ts b/tests/unit/env.test.ts index 6faa617d..2f085f1a 100644 --- a/tests/unit/env.test.ts +++ b/tests/unit/env.test.ts @@ -9,7 +9,6 @@ describe("env validation", () => { it("throws for invalid production environment variables", async () => { vi.stubEnv("NODE_ENV", "production"); vi.stubEnv("DATABASE_URL", ""); - vi.stubEnv("JWT_SECRET", ""); vi.stubEnv("AUTH_SECRET", ""); const { loadEnv } = await import("@/env"); @@ -20,7 +19,6 @@ describe("env validation", () => { it("throws for invalid test environment variables", async () => { vi.stubEnv("NODE_ENV", "test"); vi.stubEnv("DATABASE_URL", ""); - vi.stubEnv("JWT_SECRET", ""); vi.stubEnv("AUTH_SECRET", ""); const { loadEnv } = await import("@/env"); @@ -32,7 +30,6 @@ describe("env validation", () => { vi.stubEnv("NODE_ENV", "production"); vi.stubEnv("NEXT_PHASE", "phase-production-build"); vi.stubEnv("DATABASE_URL", ""); - vi.stubEnv("JWT_SECRET", ""); vi.stubEnv("AUTH_SECRET", ""); const { loadEnv } = await import("@/env"); @@ -45,7 +42,6 @@ describe("env validation", () => { it("returns a typed partial environment in development", async () => { vi.stubEnv("NODE_ENV", "development"); vi.stubEnv("DATABASE_URL", ""); - vi.stubEnv("JWT_SECRET", ""); vi.stubEnv("AUTH_SECRET", ""); const { loadEnv } = await import("@/env"); @@ -58,25 +54,56 @@ describe("env validation", () => { it("shares trimmed env helpers across auth/runtime call sites", async () => { vi.stubEnv("NODE_ENV", "development"); vi.stubEnv("DATABASE_URL", ""); - vi.stubEnv("JWT_SECRET", ""); vi.stubEnv("AUTH_SECRET", ""); - vi.stubEnv("FEATURE_FLAG", " 1 "); - vi.stubEnv("FIRST_SECRET", " "); vi.stubEnv("SECOND_SECRET", " value "); - vi.stubEnv("USERNAME", " Dev-User "); - const { - getEnvFlag, - getFirstOptionalTrimmedEnv, - getOptionalLowercaseEnv, - getOptionalTrimmedEnv, - } = await import("@/env"); + const { getOptionalTrimmedEnv } = await import("@/env"); expect(getOptionalTrimmedEnv("SECOND_SECRET")).toBe("value"); - expect(getOptionalLowercaseEnv("USERNAME")).toBe("dev-user"); - expect(getFirstOptionalTrimmedEnv(["FIRST_SECRET", "SECOND_SECRET"])).toBe( - "value", + }); + + it("leaves logger-only settings out of env validation", async () => { + const { loadEnv } = await import("@/env"); + + expect( + loadEnv({ + input: { + NODE_ENV: "development", + LOG_LEVEL: "invalid", + }, + }), + ).toEqual({ + NODE_ENV: "development", + }); + }); + + it("keeps storage env scoped to app-read settings", async () => { + const { getStorageEnv } = await import("@/env"); + + expect( + getStorageEnv({ + S3_BUCKET: " bucket ", + AWS_REGION: " us-east-1 ", + AWS_ENDPOINT_URL_S3: " http://127.0.0.1:9000 ", + AWS_ACCESS_KEY_ID: "sdk-managed", + AWS_SECRET_ACCESS_KEY: "sdk-managed", + AWS_SESSION_TOKEN: "sdk-managed", + }), + ).toEqual({ + S3_BUCKET: "bucket", + AWS_REGION: "us-east-1", + AWS_ENDPOINT_URL_S3: "http://127.0.0.1:9000", + }); + }); + + it("parses upload quota as an exact positive integer", async () => { + const { getUploadEnv } = await import("@/env"); + + expect(getUploadEnv({ UPLOAD_TOTAL_QUOTA_MB: " 2048 " })).toEqual({ + UPLOAD_TOTAL_QUOTA_MB: 2048, + }); + expect(() => getUploadEnv({ UPLOAD_TOTAL_QUOTA_MB: "2048mb" })).toThrow( + "Invalid upload environment variables", ); - expect(getEnvFlag("FEATURE_FLAG")).toBe(true); }); }); diff --git a/tests/unit/oauth-client-registration.test.ts b/tests/unit/oauth-client-registration.test.ts new file mode 100644 index 00000000..5070ba9d --- /dev/null +++ b/tests/unit/oauth-client-registration.test.ts @@ -0,0 +1,65 @@ +import { describe, expect, it } from "vitest"; +import { + resolveOAuthClientGrantTypes, + resolveOAuthClientScopes, +} from "@/lib/oauth/client-registration"; +import { + MCP_TOOLS_SCOPE, + OAUTH_AUTHORIZATION_CODE_GRANT_TYPE, + OAUTH_OFFLINE_ACCESS_SCOPE, + OAUTH_OPENID_SCOPE, + OAUTH_PROFILE_SCOPE, + OAUTH_REFRESH_TOKEN_GRANT_TYPE, +} from "@/lib/oauth/constants"; + +describe("resolveOAuthClientScopes", () => { + it("uses the default OAuth profile scopes when none are requested", () => { + expect(resolveOAuthClientScopes()).toEqual({ + scopes: [OAUTH_OPENID_SCOPE, OAUTH_PROFILE_SCOPE], + }); + }); + + it("deduplicates requested scopes while preserving request order", () => { + expect( + resolveOAuthClientScopes([ + OAUTH_PROFILE_SCOPE, + MCP_TOOLS_SCOPE, + OAUTH_PROFILE_SCOPE, + ]), + ).toEqual({ + scopes: [OAUTH_PROFILE_SCOPE, MCP_TOOLS_SCOPE], + }); + }); + + it("accepts space-delimited requested scopes", () => { + expect( + resolveOAuthClientScopes( + `${OAUTH_OPENID_SCOPE} ${MCP_TOOLS_SCOPE} ${OAUTH_OFFLINE_ACCESS_SCOPE}`, + ), + ).toEqual({ + scopes: [OAUTH_OPENID_SCOPE, MCP_TOOLS_SCOPE, OAUTH_OFFLINE_ACCESS_SCOPE], + }); + }); + + it("rejects unsupported requested scopes", () => { + expect(resolveOAuthClientScopes([OAUTH_OPENID_SCOPE, "email"])).toEqual({ + error: "Unsupported scopes requested: email", + }); + }); + + it("uses authorization-code grants unless offline access is requested", () => { + expect( + resolveOAuthClientGrantTypes([OAUTH_OPENID_SCOPE, OAUTH_PROFILE_SCOPE]), + ).toEqual([OAUTH_AUTHORIZATION_CODE_GRANT_TYPE]); + + expect( + resolveOAuthClientGrantTypes([ + OAUTH_OPENID_SCOPE, + OAUTH_OFFLINE_ACCESS_SCOPE, + ]), + ).toEqual([ + OAUTH_AUTHORIZATION_CODE_GRANT_TYPE, + OAUTH_REFRESH_TOKEN_GRANT_TYPE, + ]); + }); +}); diff --git a/tests/unit/oauth-constants.test.ts b/tests/unit/oauth-constants.test.ts new file mode 100644 index 00000000..6788f991 --- /dev/null +++ b/tests/unit/oauth-constants.test.ts @@ -0,0 +1,38 @@ +import { describe, expect, it } from "vitest"; +import { + isSupportedOAuthClientAuthMethod, + MCP_TOOLS_SCOPE, + OAUTH_CLIENT_SECRET_BASIC_AUTH_METHOD, + OAUTH_CLIENT_SECRET_POST_AUTH_METHOD, + OAUTH_EMAIL_SCOPE, + OAUTH_OFFLINE_ACCESS_SCOPE, + OAUTH_OPENID_SCOPE, + OAUTH_PROFILE_SCOPE, + OAUTH_PROVIDER_SCOPES, + OAUTH_PUBLIC_CLIENT_AUTH_METHOD, +} from "@/lib/oauth/constants"; + +describe("oauth constants", () => { + it("detects supported OAuth client authentication methods", () => { + expect( + isSupportedOAuthClientAuthMethod(OAUTH_CLIENT_SECRET_BASIC_AUTH_METHOD), + ).toBe(true); + expect( + isSupportedOAuthClientAuthMethod(OAUTH_CLIENT_SECRET_POST_AUTH_METHOD), + ).toBe(true); + expect( + isSupportedOAuthClientAuthMethod(OAUTH_PUBLIC_CLIENT_AUTH_METHOD), + ).toBe(true); + expect(isSupportedOAuthClientAuthMethod("client_secret_jwt")).toBe(false); + }); + + it("keeps provider-advertised OAuth scopes in stable order", () => { + expect(OAUTH_PROVIDER_SCOPES).toEqual([ + OAUTH_OPENID_SCOPE, + OAUTH_PROFILE_SCOPE, + OAUTH_EMAIL_SCOPE, + OAUTH_OFFLINE_ACCESS_SCOPE, + MCP_TOOLS_SCOPE, + ]); + }); +}); diff --git a/tests/unit/oauth-debug.test.ts b/tests/unit/oauth-debug.test.ts new file mode 100644 index 00000000..84e3898d --- /dev/null +++ b/tests/unit/oauth-debug.test.ts @@ -0,0 +1,69 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; +import { + getOAuthDebugMode, + sanitizeOAuthRedirectLocation, + summarizeOAuthRedirectUri, +} from "@/lib/log/oauth-debug"; + +describe("oauth debug logging", () => { + afterEach(() => { + vi.unstubAllEnvs(); + }); + + it("parses debug logging modes from env", () => { + vi.stubEnv("OAUTH_DEBUG_LOGGING", ""); + expect(getOAuthDebugMode()).toBe("off"); + + vi.stubEnv("OAUTH_DEBUG_LOGGING", " false "); + expect(getOAuthDebugMode()).toBe("off"); + + vi.stubEnv("OAUTH_DEBUG_LOGGING", "2"); + expect(getOAuthDebugMode()).toBe("verbose"); + + vi.stubEnv("OAUTH_DEBUG_LOGGING", "VERBOSE"); + expect(getOAuthDebugMode()).toBe("verbose"); + + vi.stubEnv("OAUTH_DEBUG_LOGGING", " verbose "); + expect(getOAuthDebugMode()).toBe("verbose"); + + vi.stubEnv("OAUTH_DEBUG_LOGGING", "1"); + expect(getOAuthDebugMode()).toBe("standard"); + }); + + it("redacts sensitive redirect query values", () => { + expect( + sanitizeOAuthRedirectLocation( + "/callback?code=secret&state=ok&access_token=token", + "https://life.example.com/api/auth", + ), + ).toBe( + "https://life.example.com/callback?code=%5BREDACTED%5D&state=ok&access_token=%5BREDACTED%5D", + ); + }); + + it("summarizes redirect URL shape without query values", () => { + expect( + summarizeOAuthRedirectUri( + "https://client.example:8443/callback?state=ok&code=secret", + ), + ).toEqual({ + redirectOrigin: "https://client.example:8443", + redirectHost: "client.example:8443", + redirectHostname: "client.example", + redirectPort: "8443", + redirectPath: "/callback", + redirectQueryKeys: ["code", "state"], + }); + }); + + it("keeps invalid redirect summaries explicit", () => { + expect(summarizeOAuthRedirectUri("not a url")).toEqual({ + redirectOrigin: null, + redirectHost: "invalid_redirect_uri", + redirectHostname: null, + redirectPort: null, + redirectPath: null, + redirectQueryKeys: [], + }); + }); +}); diff --git a/tests/unit/oauth-discovery-metadata.test.ts b/tests/unit/oauth-discovery-metadata.test.ts new file mode 100644 index 00000000..6d91c182 --- /dev/null +++ b/tests/unit/oauth-discovery-metadata.test.ts @@ -0,0 +1,38 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; + +describe("OAuth discovery metadata routes", () => { + afterEach(() => { + vi.resetModules(); + vi.unstubAllEnvs(); + }); + + it("adds discovery CORS headers to redirects without dropping Location", async () => { + vi.stubEnv("DATABASE_URL", "postgresql://unit:unit@127.0.0.1:5432/unit"); + vi.stubEnv("AUTH_SECRET", "unit-test-secret"); + + const { createDiscoveryRedirectRoute } = await import( + "@/lib/oauth/discovery-metadata" + ); + const route = createDiscoveryRedirectRoute( + () => + new URL( + "https://life.example/.well-known/oauth-authorization-server/api/auth", + ), + ); + + const response = await route.GET( + new Request( + "https://life.example/.well-known/oauth-authorization-server", + ), + ); + + expect(response.status).toBe(307); + expect(response.headers.get("location")).toBe( + "https://life.example/.well-known/oauth-authorization-server/api/auth", + ); + expect(response.headers.get("access-control-allow-origin")).toBe("*"); + expect(response.headers.get("access-control-allow-methods")).toBe( + "GET, OPTIONS", + ); + }); +}); diff --git a/tests/unit/oauth-loopback-redirect.test.ts b/tests/unit/oauth-loopback-redirect.test.ts index e157b35b..0cd4209c 100644 --- a/tests/unit/oauth-loopback-redirect.test.ts +++ b/tests/unit/oauth-loopback-redirect.test.ts @@ -11,7 +11,7 @@ describe("oauth loopback redirect normalization", () => { ).toBe("http://127.0.0.1:52877/callback"); }); - it("keeps strict matching for path and port", () => { + it("keeps strict matching for path, port, query, and fragment", () => { expect( resolveEquivalentLoopbackRedirectUri( ["http://127.0.0.1:52877/callback"], @@ -24,6 +24,18 @@ describe("oauth loopback redirect normalization", () => { "http://localhost:52877/other", ), ).toBeNull(); + expect( + resolveEquivalentLoopbackRedirectUri( + ["http://127.0.0.1:52877/callback?code=1#done"], + "http://localhost:52877/callback?code=2#done", + ), + ).toBeNull(); + expect( + resolveEquivalentLoopbackRedirectUri( + ["http://127.0.0.1:52877/callback?code=1#done"], + "http://localhost:52877/callback?code=1#other", + ), + ).toBeNull(); }); it("does not rewrite non-loopback URIs", () => { diff --git a/tests/unit/oauth-profile.test.ts b/tests/unit/oauth-profile.test.ts index 526dade9..b1f0b27c 100644 --- a/tests/unit/oauth-profile.test.ts +++ b/tests/unit/oauth-profile.test.ts @@ -1,5 +1,9 @@ import { describe, expect, it } from "vitest"; -import { mapOidcProfileToUser } from "@/lib/auth/oauth-profile"; +import { + mapGithubProfileToUser, + mapGoogleProfileToUser, + mapOidcProfileToUser, +} from "@/lib/auth/oauth-profile"; describe("OAuth profile mapping", () => { it("accepts sparse USTC OIDC profiles with only an id", () => { @@ -36,4 +40,81 @@ describe("OAuth profile mapping", () => { emailVerified: true, }); }); + + it("accepts camelCase email verification from OIDC profiles", () => { + expect( + mapOidcProfileToUser({ + sub: "abc", + email: "student@example.com", + emailVerified: true, + }).emailVerified, + ).toBe(true); + }); + + it("uses the first non-empty profile display name", () => { + expect( + mapOidcProfileToUser({ + sub: "abc", + name: " ", + preferred_username: " student ", + nickname: "ignored", + }).name, + ).toBe("student"); + }); + + it("maps GitHub profiles without trusting the email verification state", () => { + expect( + mapGithubProfileToUser({ + id: "octocat", + email: "octocat@example.com", + name: " Octo Cat ", + login: "ignored", + avatar_url: "https://example.com/octocat.png", + }), + ).toEqual({ + email: "octocat@example.com", + name: "Octo Cat", + image: "https://example.com/octocat.png", + emailVerified: false, + }); + }); + + it("uses a local fallback email for hidden GitHub emails", () => { + expect( + mapGithubProfileToUser({ + id: "octocat", + login: "octocat", + email: null, + }), + ).toEqual({ + email: "github-octocat@users.local", + name: "octocat", + image: undefined, + emailVerified: false, + }); + }); + + it("maps Google email verification only when an email is present", () => { + expect( + mapGoogleProfileToUser({ + sub: "google-user", + email: "student@example.com", + email_verified: true, + name: "Student", + picture: "https://example.com/google.png", + }), + ).toEqual({ + email: "student@example.com", + name: "Student", + image: "https://example.com/google.png", + emailVerified: true, + }); + + expect( + mapGoogleProfileToUser({ + sub: "google-user", + email_verified: true, + }).emailVerified, + ).toBe(false); + }); }); diff --git a/tests/unit/oauth-utils.test.ts b/tests/unit/oauth-utils.test.ts index 6ba434d7..3b3f5c0f 100644 --- a/tests/unit/oauth-utils.test.ts +++ b/tests/unit/oauth-utils.test.ts @@ -1,5 +1,4 @@ import { describe, expect, it } from "vitest"; -import { buildOAuthErrorRedirectUri } from "@/lib/oauth/redirect"; import { hashOAuthClientSecretForDbStorage, normalizeResourceIndicator, @@ -31,10 +30,10 @@ describe("oauth/utils", () => { ); }); - it("rejects resource indicators with fragments", () => { - expect(() => - normalizeResourceIndicator("https://example.com/api#frag"), - ).toThrow("must not include fragments"); + it("strips fragments from resource indicators", () => { + expect(normalizeResourceIndicator("https://example.com/api#frag")).toBe( + "https://example.com/api", + ); }); it("matches equivalent resource indicators", () => { @@ -57,17 +56,4 @@ describe("oauth/utils", () => { ), ).toBe(false); }); - - it("builds OAuth error redirect URIs", () => { - expect( - buildOAuthErrorRedirectUri({ - redirectUri: "https://client.example/callback", - error: "invalid_scope", - state: "abc123", - errorDescription: "Scope is not allowed", - }), - ).toBe( - "https://client.example/callback?error=invalid_scope&state=abc123&error_description=Scope+is+not+allowed", - ); - }); }); diff --git a/tests/unit/provider-api.test.ts b/tests/unit/provider-api.test.ts index 60a5bc7a..6629eddd 100644 --- a/tests/unit/provider-api.test.ts +++ b/tests/unit/provider-api.test.ts @@ -1,4 +1,11 @@ import { describe, expect, it } from "vitest"; +import { + OAUTH_AUTHORIZATION_CODE_GRANT_TYPE, + OAUTH_CODE_RESPONSE_TYPE, + OAUTH_OPENID_SCOPE, + OAUTH_PROFILE_SCOPE, + OAUTH_PUBLIC_CLIENT_AUTH_METHOD, +} from "@/lib/oauth/constants"; import { asGenericOAuthApi, asOAuthProviderApi, @@ -18,10 +25,10 @@ describe("provider-api guards", () => { body: { client_name: "Client", redirect_uris: ["https://example.com/callback"], - token_endpoint_auth_method: "none", - grant_types: ["authorization_code"], - response_types: ["code"], - scope: "openid profile", + token_endpoint_auth_method: OAUTH_PUBLIC_CLIENT_AUTH_METHOD, + grant_types: [OAUTH_AUTHORIZATION_CODE_GRANT_TYPE], + response_types: [OAUTH_CODE_RESPONSE_TYPE], + scope: `${OAUTH_OPENID_SCOPE} ${OAUTH_PROFILE_SCOPE}`, require_pkce: true, skip_consent: false, enable_end_session: false, diff --git a/tests/unit/signin-callback-url.test.ts b/tests/unit/signin-callback-url.test.ts index ef2dd66b..b3955035 100644 --- a/tests/unit/signin-callback-url.test.ts +++ b/tests/unit/signin-callback-url.test.ts @@ -1,5 +1,11 @@ import { describe, expect, it } from "vitest"; import { resolveSignInCallbackUrl } from "@/lib/auth/auth-routing"; +import { + MCP_TOOLS_SCOPE, + OAUTH_CODE_RESPONSE_TYPE, + OAUTH_OPENID_SCOPE, + OAUTH_PROFILE_SCOPE, +} from "@/lib/oauth/constants"; describe("resolveSignInCallbackUrl", () => { it("prefers explicit callbackUrl", () => { @@ -14,10 +20,10 @@ describe("resolveSignInCallbackUrl", () => { it("reconstructs oauth authorize continuation from raw sign-in params", () => { expect( resolveSignInCallbackUrl({ - response_type: "code", + response_type: OAUTH_CODE_RESPONSE_TYPE, client_id: "client-1", redirect_uri: "http://127.0.0.1:3000/callback", - scope: "openid profile mcp:tools", + scope: `${OAUTH_OPENID_SCOPE} ${OAUTH_PROFILE_SCOPE} ${MCP_TOOLS_SCOPE}`, state: "state-1", code_challenge: "challenge", code_challenge_method: "S256", From d85c204889ae0364c0d50b05da3f992b9a3992a1 Mon Sep 17 00:00:00 2001 From: Tiankai Ma Date: Thu, 28 May 2026 13:31:21 +0800 Subject: [PATCH 2/9] refactor(api): centralize route parsing and query helpers --- public/openapi.generated.json | 14 +- src/app/api/admin/comments/[id]/route.ts | 39 ++-- src/app/api/admin/comments/route.ts | 25 +-- src/app/api/admin/descriptions/route.ts | 27 +-- src/app/api/admin/homeworks/[id]/route.ts | 24 +- src/app/api/admin/homeworks/route.ts | 26 +-- src/app/api/admin/suspensions/[id]/route.ts | 39 ++-- src/app/api/admin/suspensions/route.ts | 6 +- src/app/api/admin/users/[id]/route.ts | 28 +-- src/app/api/admin/users/route.ts | 28 ++- src/app/api/bus/preferences/route.ts | 17 +- src/app/api/bus/route.ts | 10 +- .../calendar-subscriptions/current/route.ts | 15 +- src/app/api/calendar-subscriptions/route.ts | 10 +- src/app/api/comments/[id]/reactions/route.ts | 44 +--- src/app/api/comments/[id]/route.ts | 152 ++++++------- src/app/api/comments/route.ts | 23 +- src/app/api/courses/route.ts | 26 +-- src/app/api/dashboard-links/visit/route.ts | 28 ++- src/app/api/descriptions/route.ts | 15 +- .../api/homeworks/[id]/completion/route.ts | 34 +-- src/app/api/homeworks/[id]/route.ts | 49 ++--- src/app/api/homeworks/route.ts | 37 +--- src/app/api/me/route.ts | 16 +- .../api/me/subscriptions/homeworks/route.ts | 15 +- src/app/api/schedules/route.ts | 65 ++---- src/app/api/sections/calendar.ics/route.ts | 9 +- src/app/api/sections/route.ts | 34 +-- src/app/api/semesters/route.ts | 16 +- src/app/api/teachers/route.ts | 29 ++- src/app/api/todos/[id]/route.ts | 45 ++-- src/app/api/todos/route.ts | 56 ++--- src/app/api/uploads/[id]/download/route.ts | 36 +-- src/app/api/uploads/[id]/route.ts | 74 +++---- src/app/api/uploads/complete/route.ts | 65 +++--- src/app/api/uploads/route.ts | 24 +- src/lib/api/client.ts | 85 +++++--- src/lib/api/helpers.ts | 99 +++++++-- src/lib/api/schemas.ts | 2 - src/lib/course-section-queries.ts | 37 +--- src/lib/course-section-query-filters.ts | 102 +++------ src/lib/current-semester.ts | 48 ++-- src/lib/db/prisma.ts | 39 +--- src/lib/navigation/search-params.ts | 21 +- src/lib/query-filter-helpers.ts | 55 +++++ src/lib/query-helpers.ts | 206 ++++++++---------- src/lib/schedule-queries.ts | 99 ++------- src/lib/time/serialize-date-output.ts | 5 +- src/lib/time/shanghai-format.ts | 5 +- src/shared/lib/time-utils.ts | 58 ++++- tests/e2e/src/app/api/bus/test.ts | 20 +- tests/e2e/src/app/api/courses/test.ts | 11 + tests/e2e/src/app/api/sections/test.ts | 11 + tests/e2e/src/app/api/teachers/test.ts | 11 + .../api/users/[userId]/calendar.ics/test.ts | 17 +- tests/unit/api-client.test.ts | 54 +++++ tests/unit/api-helpers.test.ts | 31 +++ tests/unit/api-schemas.test.ts | 10 +- tests/unit/course-section-queries.test.ts | 84 +++++++ tests/unit/feature-boundaries.test.ts | 6 +- tests/unit/schedule-queries.test.ts | 2 + tests/unit/shanghai-format.test.ts | 27 +++ tests/unit/time-utils.test.ts | 12 + tools/build/openapi/generate-spec.ts | 14 +- 64 files changed, 1169 insertions(+), 1202 deletions(-) delete mode 100644 src/lib/api/schemas.ts create mode 100644 src/lib/query-filter-helpers.ts create mode 100644 tests/unit/api-client.test.ts create mode 100644 tests/unit/shanghai-format.test.ts diff --git a/public/openapi.generated.json b/public/openapi.generated.json index 2f1a3350..fad5b239 100644 --- a/public/openapi.generated.json +++ b/public/openapi.generated.json @@ -159,7 +159,7 @@ "/.well-known/oauth-protected-resource/api/mcp": { "get": { "operationId": "get-.well-known-oauth-protected-resource-api-mcp", - "summary": "", + "summary": "Canonical RFC 9728 protected resource metadata for MCP.", "description": "", "tags": [ "Api" @@ -167,27 +167,19 @@ "parameters": [], "responses": { "200": { - "description": "Successful response", - "content": { - "application/json": { - "schema": {} - } - } + "description": "Response 200" } } }, "options": { "operationId": "options-.well-known-oauth-protected-resource-api-mcp", - "summary": "Canonical RFC 9728 protected resource metadata for MCP.", + "summary": "", "description": "", "tags": [ "Api" ], "parameters": [], "responses": { - "200": { - "description": "Response 200" - }, "204": { "description": "Response 204" } diff --git a/src/app/api/admin/comments/[id]/route.ts b/src/app/api/admin/comments/[id]/route.ts index 3566ff3c..2940b8e4 100644 --- a/src/app/api/admin/comments/[id]/route.ts +++ b/src/app/api/admin/comments/[id]/route.ts @@ -2,33 +2,16 @@ import type { CommentStatus } from "@/generated/prisma/client"; import { withAdminRoute } from "@/lib/admin-utils"; import { jsonResponse, + notFound, + parseResourceIdParam, parseRouteJsonBody, - parseRouteParams, } from "@/lib/api/helpers"; -import { - adminModerateCommentRequestSchema, - resourceIdPathParamsSchema, -} from "@/lib/api/schemas/request-schemas"; -import { writeAuditLog } from "@/lib/audit/write-audit-log"; +import { adminModerateCommentRequestSchema } from "@/lib/api/schemas/request-schemas"; +import { fireAuditLog } from "@/lib/audit/write-audit-log"; import { prisma } from "@/lib/db/prisma"; export const dynamic = "force-dynamic"; -async function parseCommentId( - params: Promise<{ id: string }>, -): Promise { - const parsed = await parseRouteParams( - params, - resourceIdPathParamsSchema, - "Invalid comment ID", - ); - if (parsed instanceof Response) { - return parsed; - } - - return parsed.id; -} - /** * Moderate one comment. * @pathParams resourceIdPathParamsSchema @@ -41,7 +24,7 @@ export async function PATCH( { params }: { params: Promise<{ id: string }> }, ) { return withAdminRoute("Failed to update comment", async (admin) => { - const parsed = await parseCommentId(params); + const parsed = await parseResourceIdParam(params, "comment"); if (parsed instanceof Response) { return parsed; } @@ -55,6 +38,14 @@ export async function PATCH( return parsedBody; } + const existing = await prisma.comment.findUnique({ + where: { id }, + select: { id: true }, + }); + if (!existing) { + return notFound(); + } + const { status, moderationNote } = parsedBody; const updated = await prisma.comment.update({ where: { id }, @@ -67,13 +58,13 @@ export async function PATCH( }, }); - writeAuditLog({ + fireAuditLog({ action: "admin_comment_moderate", userId: admin.userId, targetId: id, targetType: "comment", metadata: { status, moderationNote: moderationNote ?? null }, - }).catch(() => {}); + }); return jsonResponse({ comment: updated }); }); diff --git a/src/app/api/admin/comments/route.ts b/src/app/api/admin/comments/route.ts index 0b0591ed..237474f7 100644 --- a/src/app/api/admin/comments/route.ts +++ b/src/app/api/admin/comments/route.ts @@ -1,10 +1,9 @@ import type { CommentStatus } from "@/generated/prisma/client"; import { withAdminRoute } from "@/lib/admin-utils"; import { - getPagination, getRequestSearchParams, jsonResponse, - parseRouteInput, + parseRouteQuery, } from "@/lib/api/helpers"; import { adminCommentsQuerySchema } from "@/lib/api/schemas/request-schemas"; import { prisma } from "@/lib/db/prisma"; @@ -22,24 +21,22 @@ const STATUS_FILTERS = ["active", "softbanned", "deleted"] as const; export async function GET(request: Request) { return withAdminRoute("Failed to fetch moderation queue", async () => { const searchParams = getRequestSearchParams(request); - const parsedQuery = parseRouteInput( - { - status: searchParams.get("status") ?? undefined, - limit: searchParams.get("limit") ?? undefined, - }, + const parsed = parseRouteQuery( + searchParams, adminCommentsQuerySchema, "Invalid moderation query", - { logErrors: true }, + { + logErrors: true, + pagination: { defaultPageSize: 50, maxPageSize: 200 }, + }, ); - if (parsedQuery instanceof Response) { - return parsedQuery; + if (parsed instanceof Response) { + return parsed; } + const { query: parsedQuery, pagination } = parsed; const status = parsedQuery.status ?? ""; - const { pageSize: limit } = getPagination(searchParams, { - defaultPageSize: 50, - maxPageSize: 200, - }); + const { pageSize: limit } = pagination; const now = new Date(); const where = diff --git a/src/app/api/admin/descriptions/route.ts b/src/app/api/admin/descriptions/route.ts index 1985420f..4054a4b7 100644 --- a/src/app/api/admin/descriptions/route.ts +++ b/src/app/api/admin/descriptions/route.ts @@ -1,9 +1,8 @@ import { withAdminRoute } from "@/lib/admin-utils"; import { - getPagination, getRequestSearchParams, jsonResponse, - parseRouteInput, + parseRouteQuery, } from "@/lib/api/helpers"; import { adminDescriptionsQuerySchema } from "@/lib/api/schemas/request-schemas"; import { prisma } from "@/lib/db/prisma"; @@ -22,28 +21,24 @@ export async function GET(request: Request) { "Failed to fetch descriptions moderation queue", async () => { const searchParams = getRequestSearchParams(request); - const parsedQuery = parseRouteInput( - { - targetType: searchParams.get("targetType") ?? undefined, - hasContent: searchParams.get("hasContent") ?? undefined, - search: searchParams.get("search") ?? undefined, - limit: searchParams.get("limit") ?? undefined, - }, + const parsed = parseRouteQuery( + searchParams, adminDescriptionsQuerySchema, "Invalid descriptions moderation query", - { logErrors: true }, + { + logErrors: true, + pagination: { defaultPageSize: 50, maxPageSize: 200 }, + }, ); - if (parsedQuery instanceof Response) { - return parsedQuery; + if (parsed instanceof Response) { + return parsed; } + const { query: parsedQuery, pagination } = parsed; const targetType = parsedQuery.targetType ?? "all"; const hasContent = parsedQuery.hasContent ?? "withContent"; const search = parsedQuery.search?.trim() ?? ""; - const { pageSize: limit } = getPagination(searchParams, { - defaultPageSize: 50, - maxPageSize: 200, - }); + const { pageSize: limit } = pagination; const targetTypeWhere = targetType === "section" diff --git a/src/app/api/admin/homeworks/[id]/route.ts b/src/app/api/admin/homeworks/[id]/route.ts index 48e26434..9e14101f 100644 --- a/src/app/api/admin/homeworks/[id]/route.ts +++ b/src/app/api/admin/homeworks/[id]/route.ts @@ -1,25 +1,13 @@ import { withAdminRoute } from "@/lib/admin-utils"; -import { jsonResponse, notFound, parseRouteParams } from "@/lib/api/helpers"; -import { resourceIdPathParamsSchema } from "@/lib/api/schemas/request-schemas"; +import { + jsonResponse, + notFound, + parseResourceIdParam, +} from "@/lib/api/helpers"; import { prisma } from "@/lib/db/prisma"; export const dynamic = "force-dynamic"; -async function parseHomeworkId( - params: Promise<{ id: string }>, -): Promise { - const parsed = await parseRouteParams( - params, - resourceIdPathParamsSchema, - "Invalid homework ID", - ); - if (parsed instanceof Response) { - return parsed; - } - - return parsed.id; -} - /** * Soft delete one homework (admin). * @pathParams resourceIdPathParamsSchema @@ -31,7 +19,7 @@ export async function DELETE( { params }: { params: Promise<{ id: string }> }, ) { return withAdminRoute("Failed to delete homework (admin)", async (admin) => { - const parsed = await parseHomeworkId(params); + const parsed = await parseResourceIdParam(params, "homework"); if (parsed instanceof Response) { return parsed; } diff --git a/src/app/api/admin/homeworks/route.ts b/src/app/api/admin/homeworks/route.ts index 50b675ef..a475b573 100644 --- a/src/app/api/admin/homeworks/route.ts +++ b/src/app/api/admin/homeworks/route.ts @@ -1,9 +1,8 @@ import { withAdminRoute } from "@/lib/admin-utils"; import { - getPagination, getRequestSearchParams, jsonResponse, - parseRouteInput, + parseRouteQuery, } from "@/lib/api/helpers"; import { adminHomeworksQuerySchema } from "@/lib/api/schemas/request-schemas"; import { prisma } from "@/lib/db/prisma"; @@ -22,25 +21,22 @@ export async function GET(request: Request) { "Failed to fetch homework moderation queue", async () => { const searchParams = getRequestSearchParams(request); - const parsedQuery = parseRouteInput( - { - status: searchParams.get("status") ?? undefined, - limit: searchParams.get("limit") ?? undefined, - search: searchParams.get("search") ?? undefined, - }, + const parsed = parseRouteQuery( + searchParams, adminHomeworksQuerySchema, "Invalid homework moderation query", - { logErrors: true }, + { + logErrors: true, + pagination: { defaultPageSize: 50, maxPageSize: 200 }, + }, ); - if (parsedQuery instanceof Response) { - return parsedQuery; + if (parsed instanceof Response) { + return parsed; } + const { query: parsedQuery, pagination } = parsed; const status = parsedQuery.status ?? "all"; - const { pageSize: limit } = getPagination(searchParams, { - defaultPageSize: 50, - maxPageSize: 200, - }); + const { pageSize: limit } = pagination; const search = parsedQuery.search?.trim() ?? ""; const deletedAtFilter = diff --git a/src/app/api/admin/suspensions/[id]/route.ts b/src/app/api/admin/suspensions/[id]/route.ts index 40a2220a..a62de860 100644 --- a/src/app/api/admin/suspensions/[id]/route.ts +++ b/src/app/api/admin/suspensions/[id]/route.ts @@ -1,26 +1,14 @@ import { withAdminRoute } from "@/lib/admin-utils"; -import { jsonResponse, parseRouteParams } from "@/lib/api/helpers"; -import { resourceIdPathParamsSchema } from "@/lib/api/schemas/request-schemas"; -import { writeAuditLog } from "@/lib/audit/write-audit-log"; +import { + jsonResponse, + notFound, + parseResourceIdParam, +} from "@/lib/api/helpers"; +import { fireAuditLog } from "@/lib/audit/write-audit-log"; import { prisma } from "@/lib/db/prisma"; export const dynamic = "force-dynamic"; -async function parseSuspensionId( - params: Promise<{ id: string }>, -): Promise { - const parsed = await parseRouteParams( - params, - resourceIdPathParamsSchema, - "Invalid suspension ID", - ); - if (parsed instanceof Response) { - return parsed; - } - - return parsed.id; -} - /** * Lift one suspension. * @pathParams resourceIdPathParamsSchema @@ -32,11 +20,20 @@ export async function PATCH( { params }: { params: Promise<{ id: string }> }, ) { return withAdminRoute("Failed to lift suspension", async (admin) => { - const parsed = await parseSuspensionId(params); + const parsed = await parseResourceIdParam(params, "suspension"); if (parsed instanceof Response) { return parsed; } const id = parsed; + + const existing = await prisma.userSuspension.findUnique({ + where: { id }, + select: { id: true }, + }); + if (!existing) { + return notFound(); + } + const suspension = await prisma.userSuspension.update({ where: { id }, data: { @@ -45,13 +42,13 @@ export async function PATCH( }, }); - writeAuditLog({ + fireAuditLog({ action: "admin_user_unsuspend", userId: admin.userId, targetId: suspension.userId, targetType: "user", metadata: { suspensionId: id }, - }).catch(() => {}); + }); return jsonResponse({ suspension }); }); diff --git a/src/app/api/admin/suspensions/route.ts b/src/app/api/admin/suspensions/route.ts index 0998dbeb..b4a72b43 100644 --- a/src/app/api/admin/suspensions/route.ts +++ b/src/app/api/admin/suspensions/route.ts @@ -1,7 +1,7 @@ import { withAdminRoute } from "@/lib/admin-utils"; import { jsonResponse, notFound, parseRouteJsonBody } from "@/lib/api/helpers"; import { adminCreateSuspensionRequestSchema } from "@/lib/api/schemas/request-schemas"; -import { writeAuditLog } from "@/lib/audit/write-audit-log"; +import { fireAuditLog } from "@/lib/audit/write-audit-log"; import { prisma } from "@/lib/db/prisma"; import { parseDateInput } from "@/lib/time/parse-date-input"; @@ -67,13 +67,13 @@ export async function POST(request: Request) { }, }); - writeAuditLog({ + fireAuditLog({ action: "admin_user_suspend", userId: admin.userId, targetId: userId, targetType: "user", metadata: { reason: parsedBody.reason ?? null }, - }).catch(() => {}); + }); return jsonResponse({ suspension }); }); diff --git a/src/app/api/admin/users/[id]/route.ts b/src/app/api/admin/users/[id]/route.ts index 5df16903..2996c769 100644 --- a/src/app/api/admin/users/[id]/route.ts +++ b/src/app/api/admin/users/[id]/route.ts @@ -1,15 +1,11 @@ -import { NextResponse } from "next/server"; import { withAdminRoute } from "@/lib/admin-utils"; import { badRequest, jsonResponse, - parseRouteInput, + parseResourceIdParam, parseRouteJsonBody, } from "@/lib/api/helpers"; -import { - adminUpdateUserRequestSchema, - resourceIdPathParamsSchema, -} from "@/lib/api/schemas/request-schemas"; +import { adminUpdateUserRequestSchema } from "@/lib/api/schemas/request-schemas"; import { prisma } from "@/lib/db/prisma"; export const dynamic = "force-dynamic"; @@ -27,22 +23,6 @@ function normalizeUsername(value: unknown) { return trimmed ? trimmed : null; } -async function parseUserId( - params: Promise<{ id: string }>, -): Promise { - const raw = await params; - const parsed = parseRouteInput( - raw, - resourceIdPathParamsSchema, - "Invalid user ID", - ); - if (parsed instanceof Response) { - return badRequest("Invalid user ID"); - } - - return parsed.id; -} - /** * Update one user. * @pathParams resourceIdPathParamsSchema @@ -55,8 +35,8 @@ export async function PATCH( { params }: { params: Promise<{ id: string }> }, ) { return withAdminRoute("Failed to update user", async () => { - const parsed = await parseUserId(params); - if (parsed instanceof NextResponse) { + const parsed = await parseResourceIdParam(params, "user"); + if (parsed instanceof Response) { return parsed; } const id = parsed; diff --git a/src/app/api/admin/users/route.ts b/src/app/api/admin/users/route.ts index bf434d96..bae1a540 100644 --- a/src/app/api/admin/users/route.ts +++ b/src/app/api/admin/users/route.ts @@ -3,10 +3,9 @@ import { ADMIN_USERS_PAGE_SIZE } from "@/app/admin/users/constants"; import { withAdminRoute } from "@/lib/admin-utils"; import { buildPaginatedResponse, - getPagination, getRequestSearchParams, jsonResponse, - parseRouteInput, + parseRouteQuery, } from "@/lib/api/helpers"; import { adminUsersQuerySchema } from "@/lib/api/schemas/request-schemas"; import { prisma } from "@/lib/db/prisma"; @@ -23,24 +22,23 @@ export const dynamic = "force-dynamic"; export async function GET(request: NextRequest) { return withAdminRoute("Failed to fetch users", async () => { const searchParams = getRequestSearchParams(request); - const parsedQuery = parseRouteInput( - { - search: searchParams.get("search") ?? undefined, - page: searchParams.get("page") ?? undefined, - limit: searchParams.get("limit") ?? undefined, - }, + const parsed = parseRouteQuery( + searchParams, adminUsersQuerySchema, "Invalid user query", - { logErrors: true }, + { + logErrors: true, + pagination: { + defaultPageSize: ADMIN_USERS_PAGE_SIZE, + maxPageSize: 100, + }, + }, ); - if (parsedQuery instanceof Response) { - return parsedQuery; + if (parsed instanceof Response) { + return parsed; } - const pagination = getPagination(searchParams, { - defaultPageSize: ADMIN_USERS_PAGE_SIZE, - maxPageSize: 100, - }); + const { query: parsedQuery, pagination } = parsed; const search = parsedQuery.search ?? ""; const where = search ? { diff --git a/src/app/api/bus/preferences/route.ts b/src/app/api/bus/preferences/route.ts index 33753cfe..3bb3ab43 100644 --- a/src/app/api/bus/preferences/route.ts +++ b/src/app/api/bus/preferences/route.ts @@ -6,10 +6,9 @@ import { handleRouteError, jsonResponse, parseRouteJsonBody, - unauthorized, } from "@/lib/api/helpers"; import { busPreferenceRequestSchema } from "@/lib/api/schemas/request-schemas"; -import { resolveApiUserId } from "@/lib/auth/helpers"; +import { requireAuth } from "@/lib/auth/helpers"; export const dynamic = "force-dynamic"; @@ -19,10 +18,9 @@ export const dynamic = "force-dynamic"; * @response 401:openApiErrorSchema */ export async function GET(request: Request) { - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; try { const preference = await getBusPreference(userId); @@ -40,10 +38,9 @@ export async function GET(request: Request) { * @response 400:openApiErrorSchema */ export async function POST(request: Request) { - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; const parsedBody = await parseRouteJsonBody( request, diff --git a/src/app/api/bus/route.ts b/src/app/api/bus/route.ts index 86e87c7d..84b32600 100644 --- a/src/app/api/bus/route.ts +++ b/src/app/api/bus/route.ts @@ -5,10 +5,10 @@ import { handleRouteError, jsonResponse, notFound, - parseRouteInput, + parseRouteSearchParams, } from "@/lib/api/helpers"; -import { busQueryResponseSchema } from "@/lib/api/schemas"; import { busQuerySchema } from "@/lib/api/schemas/request-schemas"; +import { busQueryResponseSchema } from "@/lib/api/schemas/response-schemas"; import { resolveApiUserId } from "@/lib/auth/helpers"; export const dynamic = "force-dynamic"; @@ -21,10 +21,8 @@ export const dynamic = "force-dynamic"; */ export async function GET(request: NextRequest) { const searchParams = request.nextUrl.searchParams; - const parsedQuery = parseRouteInput( - { - versionKey: searchParams.get("versionKey") ?? undefined, - }, + const parsedQuery = parseRouteSearchParams( + searchParams, busQuerySchema, "Invalid bus query", { logErrors: true }, diff --git a/src/app/api/calendar-subscriptions/current/route.ts b/src/app/api/calendar-subscriptions/current/route.ts index 8c6262be..3c140b93 100644 --- a/src/app/api/calendar-subscriptions/current/route.ts +++ b/src/app/api/calendar-subscriptions/current/route.ts @@ -1,10 +1,6 @@ import { getUserCalendarSubscription } from "@/features/home/server/subscription-read-model"; -import { - handleRouteError, - jsonResponse, - unauthorized, -} from "@/lib/api/helpers"; -import { resolveApiUserId } from "@/lib/auth/helpers"; +import { handleRouteError, jsonResponse } from "@/lib/api/helpers"; +import { requireAuth } from "@/lib/auth/helpers"; export const dynamic = "force-dynamic"; @@ -15,10 +11,9 @@ export const dynamic = "force-dynamic"; */ export async function GET(request: Request) { try { - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; const subscription = await getUserCalendarSubscription(userId); diff --git a/src/app/api/calendar-subscriptions/route.ts b/src/app/api/calendar-subscriptions/route.ts index 1a79b56e..3b725259 100644 --- a/src/app/api/calendar-subscriptions/route.ts +++ b/src/app/api/calendar-subscriptions/route.ts @@ -3,10 +3,9 @@ import { handleRouteError, jsonResponse, parseRouteJsonBody, - unauthorized, } from "@/lib/api/helpers"; import { calendarSubscriptionCreateRequestSchema } from "@/lib/api/schemas/request-schemas"; -import { resolveApiUserId } from "@/lib/auth/helpers"; +import { requireAuth } from "@/lib/auth/helpers"; export const dynamic = "force-dynamic"; @@ -18,10 +17,9 @@ export const dynamic = "force-dynamic"; */ export async function POST(request: Request) { try { - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; const parsedBody = await parseRouteJsonBody( request, diff --git a/src/app/api/comments/[id]/reactions/route.ts b/src/app/api/comments/[id]/reactions/route.ts index d50fa8db..7cff3bc0 100644 --- a/src/app/api/comments/[id]/reactions/route.ts +++ b/src/app/api/comments/[id]/reactions/route.ts @@ -2,35 +2,16 @@ import { handleRouteError, jsonResponse, notFound, - parseRouteInput, + parseResourceIdParam, parseRouteJsonBody, - parseRouteParams, - unauthorized, + parseRouteSearchParams, } from "@/lib/api/helpers"; -import { - commentReactionRequestSchema, - resourceIdPathParamsSchema, -} from "@/lib/api/schemas/request-schemas"; -import { requireWriteAuth, resolveApiUserId } from "@/lib/auth/helpers"; +import { commentReactionRequestSchema } from "@/lib/api/schemas/request-schemas"; +import { requireAuth, requireWriteAuth } from "@/lib/auth/helpers"; import { prisma } from "@/lib/db/prisma"; export const dynamic = "force-dynamic"; -async function parseCommentId( - params: Promise<{ id: string }>, -): Promise { - const parsed = await parseRouteParams( - params, - resourceIdPathParamsSchema, - "Invalid comment ID", - ); - if (parsed instanceof Response) { - return parsed; - } - - return parsed.id; -} - /** * Add one reaction to a comment. * @pathParams resourceIdPathParamsSchema @@ -47,7 +28,7 @@ export async function POST( } const { userId } = auth; - const parsed = await parseCommentId(params); + const parsed = await parseResourceIdParam(params, "comment"); if (parsed instanceof Response) { return parsed; } @@ -104,21 +85,18 @@ export async function DELETE( request: Request, { params }: { params: Promise<{ id: string }> }, ) { - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; - const parsed = await parseCommentId(params); + const parsed = await parseResourceIdParam(params, "comment"); if (parsed instanceof Response) { return parsed; } const id = parsed; const { searchParams } = new URL(request.url); - const parsedBody = parseRouteInput( - { - type: searchParams.get("type"), - }, + const parsedBody = parseRouteSearchParams( + searchParams, commentReactionRequestSchema, "Invalid reaction", { logErrors: true }, diff --git a/src/app/api/comments/[id]/route.ts b/src/app/api/comments/[id]/route.ts index 41b4f426..4912f748 100644 --- a/src/app/api/comments/[id]/route.ts +++ b/src/app/api/comments/[id]/route.ts @@ -1,4 +1,3 @@ -import { NextResponse } from "next/server"; import { buildCommentNodes, type CommentNode, @@ -9,17 +8,14 @@ import { handleRouteError, jsonResponse, notFound, - parseRouteInput, + parseResourceIdParam, parseRouteJsonBody, unauthorized, } from "@/lib/api/helpers"; +import { commentUpdateRequestSchema } from "@/lib/api/schemas/request-schemas"; import { - commentUpdateRequestSchema, - resourceIdPathParamsSchema, -} from "@/lib/api/schemas/request-schemas"; -import { + fireAuditLog, getAuditRequestMetadata, - writeAuditLog, } from "@/lib/audit/write-audit-log"; import { resolveApiUserId } from "@/lib/auth/helpers"; import { getViewerContext } from "@/lib/auth/viewer-context"; @@ -36,22 +32,6 @@ function findComment(nodes: CommentNode[], id: string): CommentNode | null { return null; } -async function parseCommentId( - params: Promise<{ id: string }>, -): Promise { - const raw = await params; - const parsed = parseRouteInput( - raw, - resourceIdPathParamsSchema, - "Invalid comment ID", - ); - if (parsed instanceof Response) { - return badRequest("Invalid comment ID"); - } - - return parsed.id; -} - /** * Get one comment thread by comment ID. * @pathParams resourceIdPathParamsSchema @@ -62,70 +42,74 @@ export async function GET( request: Request, { params }: { params: Promise<{ id: string }> }, ) { - const parsed = await parseCommentId(params); - if (parsed instanceof NextResponse) { + const parsed = await parseResourceIdParam(params, "comment"); + if (parsed instanceof Response) { return parsed; } const id = parsed; try { - const comment = await prisma.comment.findUnique({ - where: { id }, - select: { - sectionId: true, - courseId: true, - teacherId: true, - sectionTeacherId: true, - rootId: true, - id: true, - homework: { - select: { - id: true, - title: true, - section: { - select: { jwId: true, code: true }, + const viewerUserId = await resolveApiUserId(request); + + // Fetch the anchor comment and the viewer context in parallel. + const [comment, viewer] = await Promise.all([ + prisma.comment.findUnique({ + where: { id }, + select: { + sectionId: true, + courseId: true, + teacherId: true, + sectionTeacherId: true, + rootId: true, + id: true, + homework: { + select: { + id: true, + title: true, + section: { + select: { jwId: true, code: true }, + }, }, }, - }, - sectionTeacher: { - select: { - sectionId: true, - teacherId: true, - section: { - select: { - jwId: true, - code: true, - course: { - select: { jwId: true, nameCn: true }, + sectionTeacher: { + select: { + sectionId: true, + teacherId: true, + section: { + select: { + jwId: true, + code: true, + course: { + select: { jwId: true, nameCn: true }, + }, }, }, - }, - teacher: { - select: { nameCn: true }, + teacher: { + select: { nameCn: true }, + }, }, }, + section: { + select: { jwId: true, code: true }, + }, + course: { + select: { jwId: true, nameCn: true }, + }, + teacher: { + select: { nameCn: true }, + }, }, - section: { - select: { jwId: true, code: true }, - }, - course: { - select: { jwId: true, nameCn: true }, - }, - teacher: { - select: { nameCn: true }, - }, - }, - }); + }), + getViewerContext({ + includeAdmin: false, + userId: viewerUserId, + }), + ]); if (!comment) { return notFound(); } - const viewerUserId = await resolveApiUserId(request); - const viewer = await getViewerContext({ - includeAdmin: false, - userId: viewerUserId, - }); const threadKey = comment.rootId ?? comment.id; const threadComments = await prisma.comment.findMany({ @@ -221,8 +205,8 @@ export async function PATCH( request: Request, { params }: { params: Promise<{ id: string }> }, ) { - const parsed = await parseCommentId(params); - if (parsed instanceof NextResponse) { + const parsed = await parseResourceIdParam(params, "comment"); + if (parsed instanceof Response) { return parsed; } const id = parsed; @@ -235,16 +219,10 @@ export async function PATCH( return parsedBody; } + // Use schema-parsed values directly — Zod already validated enums/booleans. const content = parsedBody.body; - - const visibility = - typeof parsedBody.visibility === "string" - ? parsedBody.visibility - : undefined; - const isAnonymous = - typeof parsedBody.isAnonymous === "boolean" - ? parsedBody.isAnonymous - : undefined; + const visibility = parsedBody.visibility; + const isAnonymous = parsedBody.isAnonymous; const hasAttachmentUpdate = Array.isArray(parsedBody.attachmentIds); const attachmentIds = hasAttachmentUpdate @@ -367,14 +345,14 @@ export async function PATCH( const { roots } = buildCommentNodes([updatedComment], viewer); - writeAuditLog({ + fireAuditLog({ action: "comment_edit", userId, targetId: id, targetType: "comment", metadata: { body: content?.slice(0, 200) }, ...getAuditRequestMetadata(request), - }).catch(() => {}); + }); return jsonResponse({ success: true, comment: roots[0] }); } catch (error) { @@ -392,8 +370,8 @@ export async function DELETE( request: Request, { params }: { params: Promise<{ id: string }> }, ) { - const parsed = await parseCommentId(params); - if (parsed instanceof NextResponse) { + const parsed = await parseResourceIdParam(params, "comment"); + if (parsed instanceof Response) { return parsed; } const id = parsed; @@ -425,13 +403,13 @@ export async function DELETE( }, }); - writeAuditLog({ + fireAuditLog({ action: "comment_delete", userId, targetId: id, targetType: "comment", ...getAuditRequestMetadata(request), - }).catch(() => {}); + }); return jsonResponse({ success: true }); } catch (error) { diff --git a/src/app/api/comments/route.ts b/src/app/api/comments/route.ts index 81ba71af..4105d7b5 100644 --- a/src/app/api/comments/route.ts +++ b/src/app/api/comments/route.ts @@ -5,16 +5,16 @@ import { handleRouteError, jsonResponse, notFound, - parseRouteInput, parseRouteJsonBody, + parseRouteSearchParams, } from "@/lib/api/helpers"; import { commentCreateRequestSchema, commentsQuerySchema, } from "@/lib/api/schemas/request-schemas"; import { + fireAuditLog, getAuditRequestMetadata, - writeAuditLog, } from "@/lib/audit/write-audit-log"; import { requireWriteAuth, resolveApiUserId } from "@/lib/auth/helpers"; import { getViewerContext } from "@/lib/auth/viewer-context"; @@ -30,13 +30,8 @@ export const dynamic = "force-dynamic"; */ export async function GET(request: Request) { const { searchParams } = new URL(request.url); - const parsedQuery = parseRouteInput( - { - targetType: searchParams.get("targetType"), - targetId: searchParams.get("targetId") ?? undefined, - sectionId: searchParams.get("sectionId") ?? undefined, - teacherId: searchParams.get("teacherId") ?? undefined, - }, + const parsedQuery = parseRouteSearchParams( + searchParams, commentsQuerySchema, "Invalid target", ); @@ -139,6 +134,7 @@ export async function POST(request: Request) { const targetType = parsedBody.targetType; const content = parsedBody.body; + // Use schema-parsed values directly — Zod already validated the enum/boolean. const visibility = parsedBody.visibility ?? "public"; const isAnonymous = parsedBody.isAnonymous === true; @@ -154,10 +150,15 @@ export async function POST(request: Request) { sectionId: parsedBody.sectionId, targetType, teacherId: parsedBody.teacherId, + // Verify target entity exists to prevent orphan comments on deleted targets. + verifyExistence: true, }); if (!target) { return badRequest("Invalid target"); } + if (!target.verified) { + return notFound("Target not found"); + } let parentId: string | null = null; let rootId: string | null = null; @@ -224,14 +225,14 @@ export async function POST(request: Request) { }); } - writeAuditLog({ + fireAuditLog({ action: "comment_create", userId, targetId: comment.id, targetType: "comment", metadata: { body: content.slice(0, 200) }, ...getAuditRequestMetadata(request), - }).catch(() => {}); + }); return jsonResponse({ id: comment.id }); } catch (error) { diff --git a/src/app/api/courses/route.ts b/src/app/api/courses/route.ts index 5dd61464..bb21d48e 100644 --- a/src/app/api/courses/route.ts +++ b/src/app/api/courses/route.ts @@ -1,9 +1,8 @@ import type { NextRequest } from "next/server"; import { - getPagination, handleRouteError, jsonResponse, - parseRouteInput, + parseRouteQuery, } from "@/lib/api/helpers"; import { coursesQuerySchema } from "@/lib/api/schemas/request-schemas"; import { buildCourseListWhere } from "@/lib/course-section-queries"; @@ -19,24 +18,17 @@ export const dynamic = "force-dynamic"; */ export async function GET(request: NextRequest) { const searchParams = request.nextUrl.searchParams; - const parsedQuery = parseRouteInput( - { - search: searchParams.get("search") ?? undefined, - educationLevelId: searchParams.get("educationLevelId") ?? undefined, - categoryId: searchParams.get("categoryId") ?? undefined, - classTypeId: searchParams.get("classTypeId") ?? undefined, - page: searchParams.get("page") ?? undefined, - limit: searchParams.get("limit") ?? undefined, - }, + const parsed = parseRouteQuery( + searchParams, coursesQuerySchema, "Invalid course query", { logErrors: true }, ); - if (parsedQuery instanceof Response) { - return parsedQuery; + if (parsed instanceof Response) { + return parsed; } - const pagination = getPagination(searchParams); + const { query: parsedQuery, pagination } = parsed; const { search, educationLevelId, categoryId, classTypeId } = parsedQuery; const where = buildCourseListWhere({ search, @@ -46,7 +38,11 @@ export async function GET(request: NextRequest) { }); try { - const result = await paginatedCourseQuery(pagination.page, where); + const result = await paginatedCourseQuery( + pagination.page, + pagination.pageSize, + where, + ); return jsonResponse(result); } catch (error) { return handleRouteError("Failed to fetch courses", error); diff --git a/src/app/api/dashboard-links/visit/route.ts b/src/app/api/dashboard-links/visit/route.ts index 90172e63..d441f453 100644 --- a/src/app/api/dashboard-links/visit/route.ts +++ b/src/app/api/dashboard-links/visit/route.ts @@ -10,6 +10,14 @@ import { logAppEvent } from "@/lib/log/app-logger"; export const dynamic = "force-dynamic"; +function resolveVisitTarget( + schema: typeof dashboardLinkVisitQuerySchema, + slug: FormDataEntryValue | string | null, +) { + const parsed = schema.safeParse({ slug }); + return parsed.success ? resolveDashboardLinkBySlug(parsed.data.slug) : null; +} + /** * Redirect to one dashboard link without side effects. * @params dashboardLinkVisitQuerySchema @@ -17,12 +25,10 @@ export const dynamic = "force-dynamic"; */ export async function GET(request: Request) { const { searchParams } = new URL(request.url); - const parsedQuery = dashboardLinkVisitQuerySchema.safeParse({ - slug: searchParams.get("slug"), - }); - const target = parsedQuery.success - ? resolveDashboardLinkBySlug(parsedQuery.data.slug) - : null; + const target = resolveVisitTarget( + dashboardLinkVisitQuerySchema, + searchParams.get("slug"), + ); if (!target) { return NextResponse.redirect(new URL("/", request.url)); @@ -38,12 +44,10 @@ export async function GET(request: Request) { */ export async function POST(request: Request) { const formData = await request.formData(); - const parsedBody = dashboardLinkVisitRequestSchema.safeParse({ - slug: formData.get("slug"), - }); - const target = parsedBody.success - ? resolveDashboardLinkBySlug(parsedBody.data.slug) - : null; + const target = resolveVisitTarget( + dashboardLinkVisitRequestSchema, + formData.get("slug"), + ); if (!target) { return NextResponse.redirect(new URL("/", request.url), 303); diff --git a/src/app/api/descriptions/route.ts b/src/app/api/descriptions/route.ts index 5b586cd1..ede255c7 100644 --- a/src/app/api/descriptions/route.ts +++ b/src/app/api/descriptions/route.ts @@ -8,16 +8,16 @@ import { handleRouteError, jsonResponse, notFound, - parseRouteInput, parseRouteJsonBody, + parseRouteSearchParams, } from "@/lib/api/helpers"; import { descriptionsQuerySchema, descriptionUpsertRequestSchema, } from "@/lib/api/schemas/request-schemas"; import { + fireAuditLog, getAuditRequestMetadata, - writeAuditLog, } from "@/lib/audit/write-audit-log"; import { requireWriteAuth } from "@/lib/auth/helpers"; import { prisma } from "@/lib/db/prisma"; @@ -32,11 +32,8 @@ export const dynamic = "force-dynamic"; */ export async function GET(request: Request) { const { searchParams } = new URL(request.url); - const parsedQuery = parseRouteInput( - { - targetType: searchParams.get("targetType"), - targetId: searchParams.get("targetId") ?? "", - }, + const parsedQuery = parseRouteSearchParams( + searchParams, descriptionsQuerySchema, "Invalid target", ); @@ -139,14 +136,14 @@ export async function POST(request: Request) { }); if (result.updated) { - writeAuditLog({ + fireAuditLog({ action: "description_edit", userId, targetId: result.id, targetType: "description", metadata: { targetType, content: content.slice(0, 200) }, ...getAuditRequestMetadata(request), - }).catch(() => {}); + }); } return jsonResponse({ id: result.id, updated: result.updated }); diff --git a/src/app/api/homeworks/[id]/completion/route.ts b/src/app/api/homeworks/[id]/completion/route.ts index 462eddfc..036936c6 100644 --- a/src/app/api/homeworks/[id]/completion/route.ts +++ b/src/app/api/homeworks/[id]/completion/route.ts @@ -2,34 +2,15 @@ import { handleRouteError, jsonResponse, notFound, + parseResourceIdParam, parseRouteJsonBody, - parseRouteParams, - unauthorized, } from "@/lib/api/helpers"; -import { - homeworkCompletionRequestSchema, - resourceIdPathParamsSchema, -} from "@/lib/api/schemas/request-schemas"; -import { resolveApiUserId } from "@/lib/auth/helpers"; +import { homeworkCompletionRequestSchema } from "@/lib/api/schemas/request-schemas"; +import { requireAuth } from "@/lib/auth/helpers"; import { prisma } from "@/lib/db/prisma"; export const dynamic = "force-dynamic"; -async function parseHomeworkId( - params: Promise<{ id: string }>, -): Promise { - const parsed = await parseRouteParams( - params, - resourceIdPathParamsSchema, - "Invalid homework ID", - ); - if (parsed instanceof Response) { - return parsed; - } - - return parsed.id; -} - /** * Update completion state for one homework. * @pathParams resourceIdPathParamsSchema @@ -41,7 +22,7 @@ export async function PUT( request: Request, { params }: { params: Promise<{ id: string }> }, ) { - const parsed = await parseHomeworkId(params); + const parsed = await parseResourceIdParam(params, "homework"); if (parsed instanceof Response) { return parsed; } @@ -55,10 +36,9 @@ export async function PUT( return parsedBody; } - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; try { const homework = await prisma.homework.findUnique({ diff --git a/src/app/api/homeworks/[id]/route.ts b/src/app/api/homeworks/[id]/route.ts index fbb112da..46cf2ca8 100644 --- a/src/app/api/homeworks/[id]/route.ts +++ b/src/app/api/homeworks/[id]/route.ts @@ -1,17 +1,14 @@ -import { NextResponse } from "next/server"; +import type { Prisma } from "@/generated/prisma/client"; import { badRequest, forbidden, handleRouteError, jsonResponse, notFound, - parseRouteInput, + parseResourceIdParam, parseRouteJsonBody, } from "@/lib/api/helpers"; -import { - homeworkUpdateRequestSchema, - resourceIdPathParamsSchema, -} from "@/lib/api/schemas/request-schemas"; +import { homeworkUpdateRequestSchema } from "@/lib/api/schemas/request-schemas"; import { requireWriteAuth } from "@/lib/auth/helpers"; import { getViewerContext } from "@/lib/auth/viewer-context"; import { prisma } from "@/lib/db/prisma"; @@ -19,22 +16,6 @@ import { parseDateInput } from "@/lib/time/parse-date-input"; export const dynamic = "force-dynamic"; -async function parseHomeworkId( - params: Promise<{ id: string }>, -): Promise { - const raw = await params; - const parsed = parseRouteInput( - raw, - resourceIdPathParamsSchema, - "Invalid homework ID", - ); - if (parsed instanceof Response) { - return badRequest("Invalid homework ID"); - } - - return parsed.id; -} - /** * Update one homework. * @pathParams resourceIdPathParamsSchema @@ -46,8 +27,8 @@ export async function PATCH( request: Request, { params }: { params: Promise<{ id: string }> }, ) { - const parsed = await parseHomeworkId(params); - if (parsed instanceof NextResponse) { + const parsed = await parseResourceIdParam(params, "homework"); + if (parsed instanceof Response) { return parsed; } const id = parsed; @@ -114,8 +95,8 @@ export async function PATCH( return forbidden("Homework deleted"); } - const updates: Record = { - updatedById: userId, + const updates: Prisma.HomeworkUpdateInput = { + updatedBy: { connect: { id: userId } }, }; if (title !== undefined) updates.title = title; @@ -132,7 +113,17 @@ export async function PATCH( if (submissionDueAt !== undefined) updates.submissionDueAt = submissionDueAt; - if (Object.keys(updates).length === 1) { + // Only count user-provided fields; `updatedBy` is always present. + const userFieldCount = [ + title !== undefined, + parsedBody.isMajor !== undefined, + parsedBody.requiresTeam !== undefined, + hasPublishedAt, + hasSubmissionStartAt, + hasSubmissionDueAt, + ].filter(Boolean).length; + + if (userFieldCount === 0) { return badRequest("No changes"); } @@ -157,8 +148,8 @@ export async function DELETE( request: Request, { params }: { params: Promise<{ id: string }> }, ) { - const parsed = await parseHomeworkId(params); - if (parsed instanceof NextResponse) { + const parsed = await parseResourceIdParam(params, "homework"); + if (parsed instanceof Response) { return parsed; } const id = parsed; diff --git a/src/app/api/homeworks/route.ts b/src/app/api/homeworks/route.ts index e0e118b9..359246f9 100644 --- a/src/app/api/homeworks/route.ts +++ b/src/app/api/homeworks/route.ts @@ -5,8 +5,8 @@ import { jsonResponse, notFound, parseInteger, - parseRouteInput, parseRouteJsonBody, + parseRouteSearchParams, } from "@/lib/api/helpers"; import { homeworkCreateRequestSchema, @@ -16,6 +16,7 @@ import { requireWriteAuth, resolveApiUserId } from "@/lib/auth/helpers"; import { getViewerContext } from "@/lib/auth/viewer-context"; import { getPrisma, prisma } from "@/lib/db/prisma"; import { parseDateInput } from "@/lib/time/parse-date-input"; + export const dynamic = "force-dynamic"; /** @@ -26,12 +27,8 @@ export const dynamic = "force-dynamic"; */ export async function GET(request: Request) { const { searchParams } = new URL(request.url); - const parsedQuery = parseRouteInput( - { - sectionId: searchParams.get("sectionId") ?? undefined, - sectionIds: searchParams.get("sectionIds") ?? undefined, - includeDeleted: searchParams.get("includeDeleted") ?? undefined, - }, + const parsedQuery = parseRouteSearchParams( + searchParams, homeworksQuerySchema, "Invalid homework query", { logErrors: true }, @@ -85,6 +82,11 @@ export async function GET(request: Request) { deletedBy: { select: { id: true, name: true, username: true, image: true }, }, + _count: { + select: { + comments: { where: { status: { not: "deleted" } } }, + }, + }, ...(viewer.userId ? { homeworkCompletions: { @@ -118,30 +120,13 @@ export async function GET(request: Request) { take: 50, }), ]); - const homeworkIds = homeworks.map((homework) => homework.id); - const commentCountRows = - homeworkIds.length > 0 - ? await prisma.comment.groupBy({ - by: ["homeworkId"], - where: { - homeworkId: { in: homeworkIds }, - status: { not: "deleted" }, - }, - _count: { _all: true }, - }) - : []; - const commentCounts = new Map( - commentCountRows.flatMap((row) => - row.homeworkId ? [[row.homeworkId, row._count._all] as const] : [], - ), - ); const responseHomeworks = homeworks.map((homework) => { - const { homeworkCompletions, ...rest } = homework; + const { homeworkCompletions, _count, ...rest } = homework; return { ...rest, completion: homeworkCompletions?.[0] ?? null, - commentCount: commentCounts.get(homework.id) ?? 0, + commentCount: _count.comments, }; }); diff --git a/src/app/api/me/route.ts b/src/app/api/me/route.ts index a9510e6a..b51f49dd 100644 --- a/src/app/api/me/route.ts +++ b/src/app/api/me/route.ts @@ -1,10 +1,5 @@ -import { - handleRouteError, - jsonResponse, - notFound, - unauthorized, -} from "@/lib/api/helpers"; -import { resolveApiUserId } from "@/lib/auth/helpers"; +import { handleRouteError, jsonResponse, notFound } from "@/lib/api/helpers"; +import { requireAuth } from "@/lib/auth/helpers"; import { prisma } from "@/lib/db/prisma"; export const dynamic = "force-dynamic"; @@ -16,10 +11,9 @@ export const dynamic = "force-dynamic"; */ export async function GET(request: Request) { try { - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; const user = await prisma.user.findUnique({ where: { id: userId }, diff --git a/src/app/api/me/subscriptions/homeworks/route.ts b/src/app/api/me/subscriptions/homeworks/route.ts index 86e8958b..52ef4103 100644 --- a/src/app/api/me/subscriptions/homeworks/route.ts +++ b/src/app/api/me/subscriptions/homeworks/route.ts @@ -4,12 +4,8 @@ import { listSubscribedHomeworks, } from "@/features/home/server/subscription-read-model"; import { withHomeworkItemState } from "@/features/homeworks/server/homework-item-state"; -import { - handleRouteError, - jsonResponse, - unauthorized, -} from "@/lib/api/helpers"; -import { resolveApiUserId } from "@/lib/auth/helpers"; +import { handleRouteError, jsonResponse } from "@/lib/api/helpers"; +import { requireAuth } from "@/lib/auth/helpers"; import { getViewerContext } from "@/lib/auth/viewer-context"; export const dynamic = "force-dynamic"; @@ -20,10 +16,9 @@ export const dynamic = "force-dynamic"; * @response 401:openApiErrorSchema */ export async function GET(request: Request) { - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; try { const viewer = await getViewerContext({ diff --git a/src/app/api/schedules/route.ts b/src/app/api/schedules/route.ts index e282321a..174746f2 100644 --- a/src/app/api/schedules/route.ts +++ b/src/app/api/schedules/route.ts @@ -1,10 +1,9 @@ import type { NextRequest } from "next/server"; import { buildPaginatedResponse, - getPagination, handleRouteError, jsonResponse, - parseRouteInput, + parseRouteQuery, } from "@/lib/api/helpers"; import { schedulesQuerySchema } from "@/lib/api/schemas/request-schemas"; import { getPrisma, prisma } from "@/lib/db/prisma"; @@ -16,6 +15,17 @@ import { parseDateInput } from "@/lib/time/parse-date-input"; import { formatTime } from "@/shared/lib/time-utils"; export const dynamic = "force-dynamic"; +function parseScheduleDateParam(name: "dateFrom" | "dateTo", value?: string) { + if (!value) { + return undefined; + } + + const parsed = parseDateInput(value); + return parsed instanceof Date + ? parsed + : handleRouteError("Invalid schedule query", `Invalid ${name}`, 400); +} + /** * List schedules with filters and pagination. * @params schedulesQuerySchema @@ -25,30 +35,17 @@ export const dynamic = "force-dynamic"; export async function GET(request: NextRequest) { try { const searchParams = request.nextUrl.searchParams; - const parsedQuery = parseRouteInput( - { - sectionId: searchParams.get("sectionId") ?? undefined, - sectionJwId: searchParams.get("sectionJwId") ?? undefined, - sectionCode: searchParams.get("sectionCode") ?? undefined, - teacherId: searchParams.get("teacherId") ?? undefined, - teacherCode: searchParams.get("teacherCode") ?? undefined, - roomId: searchParams.get("roomId") ?? undefined, - roomJwId: searchParams.get("roomJwId") ?? undefined, - dateFrom: searchParams.get("dateFrom") ?? undefined, - dateTo: searchParams.get("dateTo") ?? undefined, - weekday: searchParams.get("weekday") ?? undefined, - page: searchParams.get("page") ?? undefined, - limit: searchParams.get("limit") ?? undefined, - }, + const parsed = parseRouteQuery( + searchParams, schedulesQuerySchema, "Invalid schedule query", { logErrors: true }, ); - if (parsedQuery instanceof Response) { - return parsedQuery; + if (parsed instanceof Response) { + return parsed; } - const pagination = getPagination(searchParams); + const { query: parsedQuery, pagination } = parsed; const { sectionId, sectionJwId, @@ -62,29 +59,13 @@ export async function GET(request: NextRequest) { weekday, } = parsedQuery; - let parsedDateFrom: Date | undefined; - if (dateFrom) { - const nextDateFrom = parseDateInput(dateFrom); - if (!(nextDateFrom instanceof Date)) { - return handleRouteError( - "Invalid schedule query", - "Invalid dateFrom", - 400, - ); - } - parsedDateFrom = nextDateFrom; + const parsedDateFrom = parseScheduleDateParam("dateFrom", dateFrom); + if (parsedDateFrom instanceof Response) { + return parsedDateFrom; } - let parsedDateTo: Date | undefined; - if (dateTo) { - const nextDateTo = parseDateInput(dateTo); - if (!(nextDateTo instanceof Date)) { - return handleRouteError( - "Invalid schedule query", - "Invalid dateTo", - 400, - ); - } - parsedDateTo = nextDateTo; + const parsedDateTo = parseScheduleDateParam("dateTo", dateTo); + if (parsedDateTo instanceof Response) { + return parsedDateTo; } const whereClause = buildScheduleListWhere({ sectionId, diff --git a/src/app/api/sections/calendar.ics/route.ts b/src/app/api/sections/calendar.ics/route.ts index c4ef5cc5..38ed102f 100644 --- a/src/app/api/sections/calendar.ics/route.ts +++ b/src/app/api/sections/calendar.ics/route.ts @@ -3,7 +3,7 @@ import { badRequest, handleRouteError, parseIntegerList, - parseRouteInput, + parseRouteSearchParams, } from "@/lib/api/helpers"; import { sectionsCalendarQuerySchema } from "@/lib/api/schemas/request-schemas"; import { prisma } from "@/lib/db/prisma"; @@ -19,11 +19,8 @@ export const dynamic = "force-dynamic"; */ export async function GET(request: NextRequest) { try { - const { searchParams } = new URL(request.url); - const parsedQuery = parseRouteInput( - { - sectionIds: searchParams.get("sectionIds") ?? "", - }, + const parsedQuery = parseRouteSearchParams( + request.nextUrl.searchParams, sectionsCalendarQuerySchema, "sectionIds parameter is required", { logErrors: true }, diff --git a/src/app/api/sections/route.ts b/src/app/api/sections/route.ts index 9339248d..4ad5c42b 100644 --- a/src/app/api/sections/route.ts +++ b/src/app/api/sections/route.ts @@ -1,9 +1,8 @@ import type { NextRequest } from "next/server"; import { - getPagination, handleRouteError, jsonResponse, - parseRouteInput, + parseRouteQuery, } from "@/lib/api/helpers"; import { sectionsQuerySchema } from "@/lib/api/schemas/request-schemas"; import { buildSectionListQuery } from "@/lib/course-section-queries"; @@ -19,31 +18,17 @@ export const dynamic = "force-dynamic"; */ export async function GET(request: NextRequest) { const searchParams = request.nextUrl.searchParams; - const parsedQuery = parseRouteInput( - { - courseId: searchParams.get("courseId") ?? undefined, - courseJwId: searchParams.get("courseJwId") ?? undefined, - semesterId: searchParams.get("semesterId") ?? undefined, - semesterJwId: searchParams.get("semesterJwId") ?? undefined, - campusId: searchParams.get("campusId") ?? undefined, - departmentId: searchParams.get("departmentId") ?? undefined, - teacherId: searchParams.get("teacherId") ?? undefined, - teacherCode: searchParams.get("teacherCode") ?? undefined, - search: searchParams.get("search") ?? undefined, - ids: searchParams.get("ids") ?? undefined, - jwIds: searchParams.get("jwIds") ?? undefined, - page: searchParams.get("page") ?? undefined, - limit: searchParams.get("limit") ?? undefined, - }, + const parsed = parseRouteQuery( + searchParams, sectionsQuerySchema, "Invalid section query", { logErrors: true }, ); - if (parsedQuery instanceof Response) { - return parsedQuery; + if (parsed instanceof Response) { + return parsed; } - const pagination = getPagination(searchParams); + const { query: parsedQuery, pagination } = parsed; const { courseId, courseJwId, @@ -72,7 +57,12 @@ export async function GET(request: NextRequest) { }); try { - const result = await paginatedSectionQuery(pagination.page, where, orderBy); + const result = await paginatedSectionQuery( + pagination.page, + pagination.pageSize, + where, + orderBy, + ); return jsonResponse(result); } catch (error) { return handleRouteError("Failed to fetch sections", error); diff --git a/src/app/api/semesters/route.ts b/src/app/api/semesters/route.ts index 0cbadb41..336c778a 100644 --- a/src/app/api/semesters/route.ts +++ b/src/app/api/semesters/route.ts @@ -1,10 +1,9 @@ import type { NextRequest } from "next/server"; import { buildPaginatedResponse, - getPagination, handleRouteError, jsonResponse, - parseRouteInput, + parseRouteQuery, } from "@/lib/api/helpers"; import { semestersQuerySchema } from "@/lib/api/schemas/request-schemas"; import { prisma } from "@/lib/db/prisma"; @@ -20,19 +19,16 @@ export const dynamic = "force-dynamic"; export async function GET(request: NextRequest) { try { const searchParams = request.nextUrl.searchParams; - const parsedQuery = parseRouteInput( - { - page: searchParams.get("page") ?? undefined, - limit: searchParams.get("limit") ?? undefined, - }, + const parsed = parseRouteQuery( + searchParams, semestersQuerySchema, "Invalid semester query", { logErrors: true }, ); - if (parsedQuery instanceof Response) { - return parsedQuery; + if (parsed instanceof Response) { + return parsed; } - const { page, pageSize, skip } = getPagination(searchParams); + const { page, pageSize, skip } = parsed.pagination; const [semesters, total] = await Promise.all([ prisma.semester.findMany({ diff --git a/src/app/api/teachers/route.ts b/src/app/api/teachers/route.ts index e65d3d08..779b2346 100644 --- a/src/app/api/teachers/route.ts +++ b/src/app/api/teachers/route.ts @@ -1,11 +1,10 @@ import type { NextRequest } from "next/server"; import type { Prisma } from "@/generated/prisma/client"; import { - getPagination, handleRouteError, jsonResponse, parseInteger, - parseRouteInput, + parseRouteQuery, } from "@/lib/api/helpers"; import { teachersQuerySchema } from "@/lib/api/schemas/request-schemas"; import { ilike, paginatedTeacherQuery } from "@/lib/query-helpers"; @@ -20,22 +19,17 @@ export const dynamic = "force-dynamic"; */ export async function GET(request: NextRequest) { const searchParams = request.nextUrl.searchParams; - const parsedQuery = parseRouteInput( - { - departmentId: searchParams.get("departmentId") ?? undefined, - search: searchParams.get("search") ?? undefined, - page: searchParams.get("page") ?? undefined, - limit: searchParams.get("limit") ?? undefined, - }, + const parsed = parseRouteQuery( + searchParams, teachersQuerySchema, "Invalid teacher query", { logErrors: true }, ); - if (parsedQuery instanceof Response) { - return parsedQuery; + if (parsed instanceof Response) { + return parsed; } - const pagination = getPagination(searchParams); + const { query: parsedQuery, pagination } = parsed; const { departmentId, search } = parsedQuery; const where: Prisma.TeacherWhereInput = {}; @@ -56,9 +50,14 @@ export async function GET(request: NextRequest) { } try { - const result = await paginatedTeacherQuery(pagination.page, where, { - nameCn: "asc", - }); + const result = await paginatedTeacherQuery( + pagination.page, + pagination.pageSize, + where, + { + nameCn: "asc", + }, + ); return jsonResponse(result); } catch (error) { return handleRouteError("Failed to fetch teachers", error); diff --git a/src/app/api/todos/[id]/route.ts b/src/app/api/todos/[id]/route.ts index eea44f94..4a4ba75c 100644 --- a/src/app/api/todos/[id]/route.ts +++ b/src/app/api/todos/[id]/route.ts @@ -1,37 +1,20 @@ +import type { Prisma } from "@/generated/prisma/client"; import { badRequest, forbidden, handleRouteError, jsonResponse, notFound, + parseResourceIdParam, parseRouteJsonBody, - parseRouteParams, - unauthorized, } from "@/lib/api/helpers"; -import { - resourceIdPathParamsSchema, - todoUpdateRequestSchema, -} from "@/lib/api/schemas/request-schemas"; -import { resolveApiUserId } from "@/lib/auth/helpers"; +import { todoUpdateRequestSchema } from "@/lib/api/schemas/request-schemas"; +import { requireAuth } from "@/lib/auth/helpers"; import { prisma } from "@/lib/db/prisma"; import { parseDateInput } from "@/lib/time/parse-date-input"; export const dynamic = "force-dynamic"; -async function parseTodoId( - params: Promise<{ id: string }>, -): Promise { - const parsed = await parseRouteParams( - params, - resourceIdPathParamsSchema, - "Invalid todo ID", - ); - if (parsed instanceof Response) { - return parsed; - } - return parsed.id; -} - /** * Update one todo. * @pathParams resourceIdPathParamsSchema @@ -43,16 +26,15 @@ export async function PATCH( request: Request, { params }: { params: Promise<{ id: string }> }, ) { - const parsed = await parseTodoId(params); + const parsed = await parseResourceIdParam(params, "todo"); if (parsed instanceof Response) { return parsed; } const id = parsed; - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; const parsedBody = await parseRouteJsonBody( request, @@ -83,7 +65,7 @@ export async function PATCH( return forbidden(); } - const updates: Record = {}; + const updates: Prisma.TodoUpdateInput = {}; if (parsedBody.title !== undefined) updates.title = parsedBody.title; if (Object.hasOwn(parsedBody, "content")) { updates.content = parsedBody.content?.trim() || null; @@ -116,16 +98,15 @@ export async function DELETE( request: Request, { params }: { params: Promise<{ id: string }> }, ) { - const parsed = await parseTodoId(params); + const parsed = await parseResourceIdParam(params, "todo"); if (parsed instanceof Response) { return parsed; } const id = parsed; - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; try { const todo = await prisma.todo.findUnique({ diff --git a/src/app/api/todos/route.ts b/src/app/api/todos/route.ts index 851f5a45..d6f5c119 100644 --- a/src/app/api/todos/route.ts +++ b/src/app/api/todos/route.ts @@ -1,18 +1,19 @@ +import type { Prisma } from "@/generated/prisma/client"; import { badRequest, handleRouteError, jsonResponse, - parseRouteInput, parseRouteJsonBody, - unauthorized, + parseRouteSearchParams, } from "@/lib/api/helpers"; import { todoCreateRequestSchema, todosQuerySchema, } from "@/lib/api/schemas/request-schemas"; -import { resolveApiUserId } from "@/lib/auth/helpers"; +import { requireAuth } from "@/lib/auth/helpers"; import { prisma } from "@/lib/db/prisma"; import { parseDateInput } from "@/lib/time/parse-date-input"; + export const dynamic = "force-dynamic"; /** @@ -22,19 +23,13 @@ export const dynamic = "force-dynamic"; * @response 401:openApiErrorSchema */ export async function GET(request: Request) { - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; const { searchParams } = new URL(request.url); - const parsedQuery = parseRouteInput( - { - completed: searchParams.get("completed") ?? undefined, - priority: searchParams.get("priority") ?? undefined, - dueBefore: searchParams.get("dueBefore") ?? undefined, - dueAfter: searchParams.get("dueAfter") ?? undefined, - }, + const parsedQuery = parseRouteSearchParams( + searchParams, todosQuerySchema, "Invalid todo query", { logErrors: true }, @@ -43,14 +38,20 @@ export async function GET(request: Request) { return parsedQuery; } - const where: Record = { userId }; + const where: Prisma.TodoWhereInput = { userId }; if (parsedQuery.completed === "true") where.completed = true; else if (parsedQuery.completed === "false") where.completed = false; if (parsedQuery.priority) where.priority = parsedQuery.priority; if (parsedQuery.dueBefore || parsedQuery.dueAfter) { - const dueAtFilter: Record = {}; - if (parsedQuery.dueBefore) dueAtFilter.lt = new Date(parsedQuery.dueBefore); - if (parsedQuery.dueAfter) dueAtFilter.gte = new Date(parsedQuery.dueAfter); + const dueAtFilter: Prisma.TodoWhereInput["dueAt"] = {}; + if (parsedQuery.dueBefore) { + const parsed = parseDateInput(parsedQuery.dueBefore); + if (parsed) dueAtFilter.lt = parsed; + } + if (parsedQuery.dueAfter) { + const parsed = parseDateInput(parsedQuery.dueAfter); + if (parsed) dueAtFilter.gte = parsed; + } where.dueAt = dueAtFilter; } @@ -73,10 +74,9 @@ export async function GET(request: Request) { * @response 400:openApiErrorSchema */ export async function POST(request: Request) { - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; const parsedBody = await parseRouteJsonBody( request, @@ -87,9 +87,13 @@ export async function POST(request: Request) { return parsedBody; } - const dueAt = parseDateInput(parsedBody.dueAt); - if (dueAt === undefined) { - return badRequest("Invalid due date"); + const dueAtRaw = parsedBody.dueAt; + let dueAt: Date | null | undefined; + if (dueAtRaw !== undefined) { + dueAt = parseDateInput(dueAtRaw); + if (dueAt === undefined) { + return badRequest("Invalid due date"); + } } try { @@ -99,7 +103,7 @@ export async function POST(request: Request) { title: parsedBody.title, content: parsedBody.content?.trim() || null, priority: parsedBody.priority ?? "medium", - dueAt, + ...(dueAt !== undefined && { dueAt }), }, }); diff --git a/src/app/api/uploads/[id]/download/route.ts b/src/app/api/uploads/[id]/download/route.ts index 3271fee1..b698a18d 100644 --- a/src/app/api/uploads/[id]/download/route.ts +++ b/src/app/api/uploads/[id]/download/route.ts @@ -1,38 +1,17 @@ import { GetObjectCommand } from "@aws-sdk/client-s3"; import { NextResponse } from "next/server"; +import { buildContentDisposition } from "@/features/uploads/lib/upload-utils"; import { handleRouteError, notFound, - parseRouteParams, - unauthorized, + parseResourceIdParam, } from "@/lib/api/helpers"; -import { resourceIdPathParamsSchema } from "@/lib/api/schemas/request-schemas"; -import { resolveApiUserId } from "@/lib/auth/helpers"; +import { requireAuth } from "@/lib/auth/helpers"; import { prisma } from "@/lib/db/prisma"; import { getS3Bucket, getS3SignedUrl } from "@/lib/storage/s3"; export const dynamic = "force-dynamic"; -async function parseUploadId( - params: Promise<{ id: string }>, -): Promise { - const parsed = await parseRouteParams( - params, - resourceIdPathParamsSchema, - "Invalid upload ID", - ); - if (parsed instanceof Response) { - return parsed; - } - - return parsed.id; -} - -function buildContentDisposition(filename: string) { - const safeName = filename.replace(/"/g, "'"); - return `attachment; filename="${safeName}"`; -} - /** * Redirect to signed URL for one upload. * @pathParams resourceIdPathParamsSchema @@ -44,12 +23,11 @@ export async function GET( request: Request, context: { params: Promise<{ id: string }> }, ) { - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; - const parsed = await parseUploadId(context.params); + const parsed = await parseResourceIdParam(context.params, "upload"); if (parsed instanceof Response) { return parsed; } diff --git a/src/app/api/uploads/[id]/route.ts b/src/app/api/uploads/[id]/route.ts index 97c7bcd4..45fe3e2e 100644 --- a/src/app/api/uploads/[id]/route.ts +++ b/src/app/api/uploads/[id]/route.ts @@ -1,46 +1,24 @@ import { DeleteObjectCommand } from "@aws-sdk/client-s3"; +import { sanitizeFilename } from "@/features/uploads/lib/upload-utils"; import { badRequest, handleRouteError, jsonResponse, notFound, + parseResourceIdParam, parseRouteJsonBody, - parseRouteParams, - unauthorized, } from "@/lib/api/helpers"; +import { uploadRenameRequestSchema } from "@/lib/api/schemas/request-schemas"; import { - resourceIdPathParamsSchema, - uploadRenameRequestSchema, -} from "@/lib/api/schemas/request-schemas"; -import { + fireAuditLog, getAuditRequestMetadata, - writeAuditLog, } from "@/lib/audit/write-audit-log"; -import { resolveApiUserId } from "@/lib/auth/helpers"; +import { requireAuth } from "@/lib/auth/helpers"; import { prisma } from "@/lib/db/prisma"; import { getS3Bucket, sendS3 } from "@/lib/storage/s3"; export const dynamic = "force-dynamic"; -async function parseUploadId( - params: Promise<{ id: string }>, -): Promise { - const parsed = await parseRouteParams( - params, - resourceIdPathParamsSchema, - "Invalid upload ID", - ); - if (parsed instanceof Response) { - return parsed; - } - - return parsed.id; -} - -function sanitizeFilename(filename: string) { - return filename.trim(); -} - /** * Rename one upload. * @pathParams resourceIdPathParamsSchema @@ -54,12 +32,11 @@ export async function PATCH( request: Request, context: { params: Promise<{ id: string }> }, ) { - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; - const parsed = await parseUploadId(context.params); + const parsed = await parseResourceIdParam(context.params, "upload"); if (parsed instanceof Response) { return parsed; } @@ -125,12 +102,11 @@ export async function DELETE( request: Request, context: { params: Promise<{ id: string }> }, ) { - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; - const parsed = await parseUploadId(context.params); + const parsed = await parseResourceIdParam(context.params, "upload"); if (parsed instanceof Response) { return parsed; } @@ -146,20 +122,32 @@ export async function DELETE( return notFound(); } - await sendS3( - new DeleteObjectCommand({ Bucket: getS3Bucket(), Key: upload.key }), - ); - + // Delete DB record first, then S3 object. + // If S3 cleanup fails, the record is gone and the orphaned S3 object + // is harmless (no DB reference points to it). A reverse order would + // leave a DB record pointing at a missing S3 object. await prisma.upload.delete({ where: { id: upload.id } }); - writeAuditLog({ + try { + await sendS3( + new DeleteObjectCommand({ Bucket: getS3Bucket(), Key: upload.key }), + ); + } catch (s3Error) { + // S3 cleanup failure is non-critical — the DB record is already gone. + handleRouteError( + "S3 object cleanup failed after upload deletion", + s3Error, + ); + } + + fireAuditLog({ action: "upload_delete", userId, targetId: upload.id, targetType: "upload", metadata: { key: upload.key, size: upload.size }, ...getAuditRequestMetadata(request), - }).catch(() => {}); + }); return jsonResponse({ deletedId: upload.id, diff --git a/src/app/api/uploads/complete/route.ts b/src/app/api/uploads/complete/route.ts index 8839edda..011f742e 100644 --- a/src/app/api/uploads/complete/route.ts +++ b/src/app/api/uploads/complete/route.ts @@ -4,28 +4,21 @@ import { runUploadSerializableTransaction, UploadError, } from "@/features/uploads/lib/upload-quota"; +import { normalizeContentType } from "@/features/uploads/lib/upload-utils"; import { badRequest, forbidden, handleRouteError, jsonResponse, parseRouteJsonBody, - payloadTooLarge, - unauthorized, } from "@/lib/api/helpers"; import { uploadCompleteRequestSchema } from "@/lib/api/schemas/request-schemas"; -import { resolveApiUserId } from "@/lib/auth/helpers"; +import { requireAuth } from "@/lib/auth/helpers"; import { prisma } from "@/lib/db/prisma"; import { getS3Bucket, sendS3 } from "@/lib/storage/s3"; export const dynamic = "force-dynamic"; -function normalizeContentType(value: unknown) { - if (typeof value !== "string") return null; - const trimmed = value.trim(); - return trimmed.length > 0 ? trimmed : null; -} - /** * Finalize one upload after S3 put. * @body uploadCompleteRequestSchema @@ -33,10 +26,9 @@ function normalizeContentType(value: unknown) { * @response 400:openApiErrorSchema */ export async function POST(request: Request) { - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; const parsedBody = await parseRouteJsonBody( request, @@ -94,31 +86,7 @@ export async function POST(request: Request) { }); } - const pendingCheck = await prisma.uploadPending.findUnique({ - where: { key }, - }); - if (!pendingCheck || pendingCheck.userId !== userId) { - return badRequest("Upload session expired"); - } - - const bucket = getS3Bucket(); - const head = await sendS3( - new HeadObjectCommand({ Bucket: bucket, Key: key }), - ); - - const size = head.ContentLength ?? 0; - if (!size || size <= 0) { - return badRequest("Uploaded object missing"); - } - - if (size > uploadConfig.maxFileSizeBytes) { - await sendS3(new DeleteObjectCommand({ Bucket: bucket, Key: key })); - return payloadTooLarge("File too large"); - } - - const contentType = - normalizeContentType(parsedBody.contentType) ?? head.ContentType; - + // Move all checks inside the serializable transaction for consistency. const reservation = await runUploadSerializableTransaction(async (tx) => { const pending = await tx.uploadPending.findUnique({ where: { key } }); if (!pending || pending.userId !== userId) { @@ -130,6 +98,27 @@ export async function POST(request: Request) { throw new UploadError("Upload session expired"); } + // S3 HeadObject check must happen outside the transaction boundary + // since it's an external service call — we verify the file exists + // before committing quota changes. + const bucket = getS3Bucket(); + const head = await sendS3( + new HeadObjectCommand({ Bucket: bucket, Key: key }), + ); + + const size = head.ContentLength ?? 0; + if (!size || size <= 0) { + throw new UploadError("Uploaded object missing"); + } + + if (size > uploadConfig.maxFileSizeBytes) { + await sendS3(new DeleteObjectCommand({ Bucket: bucket, Key: key })); + throw new UploadError("File too large"); + } + + const contentType = + normalizeContentType(parsedBody.contentType) ?? head.ContentType; + const [usage, pendingUsage] = await Promise.all([ tx.upload.aggregate({ where: { userId }, diff --git a/src/app/api/uploads/route.ts b/src/app/api/uploads/route.ts index 56d136da..8743f697 100644 --- a/src/app/api/uploads/route.ts +++ b/src/app/api/uploads/route.ts @@ -4,6 +4,7 @@ import { runUploadSerializableTransaction, UploadError, } from "@/features/uploads/lib/upload-quota"; +import { normalizeContentType } from "@/features/uploads/lib/upload-utils"; import { badRequest, handleRouteError, @@ -11,10 +12,9 @@ import { parseInteger, parseRouteJsonBody, payloadTooLarge, - unauthorized, } from "@/lib/api/helpers"; import { uploadCreateRequestSchema } from "@/lib/api/schemas/request-schemas"; -import { resolveApiUserId } from "@/lib/auth/helpers"; +import { requireAuth } from "@/lib/auth/helpers"; import { prisma } from "@/lib/db/prisma"; import { buildUploadKey, getS3Bucket, getS3SignedUrl } from "@/lib/storage/s3"; @@ -26,21 +26,14 @@ function parseFileSize(value: unknown) { return parseInteger(value); } -function normalizeContentType(value: unknown) { - if (typeof value !== "string") return "application/octet-stream"; - const trimmed = value.trim(); - return trimmed.length > 0 ? trimmed : "application/octet-stream"; -} - /** * List uploads of current user. * @response uploadsListResponseSchema */ export async function GET(request: Request) { - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; try { const now = new Date(); @@ -96,10 +89,9 @@ export async function GET(request: Request) { * @response 400:openApiErrorSchema */ export async function POST(request: Request) { - const userId = await resolveApiUserId(request); - if (!userId) { - return unauthorized(); - } + const auth = await requireAuth(request); + if (auth instanceof Response) return auth; + const { userId } = auth; const parsedBody = await parseRouteJsonBody( request, diff --git a/src/lib/api/client.ts b/src/lib/api/client.ts index 065fb8bc..38ff2f6b 100644 --- a/src/lib/api/client.ts +++ b/src/lib/api/client.ts @@ -138,50 +138,75 @@ export const apiClient = { apiRequest("PATCH", path, options), }; +const API_ERROR_MESSAGE_KEYS = ["error", "message", "detail"] as const; + +function trimmedMessage(value: string | undefined): string | null { + return value?.trim() || null; +} + +function normalizeStringMessage(value: unknown): string | null | undefined { + if (typeof value !== "string" || !value) { + return undefined; + } + return trimmedMessage(value); +} + +function asErrorRecord(value: unknown): Record | null { + return value && typeof value === "object" + ? (value as Record) + : null; +} + +function firstStringMessage( + record: Record, + keys: readonly string[] = API_ERROR_MESSAGE_KEYS, +): string | null | undefined { + for (const key of keys) { + const message = normalizeStringMessage(record[key]); + if (message !== undefined) { + return message; + } + } + + return undefined; +} + +function firstRecordMessage( + value: unknown, + keys: readonly string[] = API_ERROR_MESSAGE_KEYS, +): string | null | undefined { + const record = asErrorRecord(value); + return record ? firstStringMessage(record, keys) : undefined; +} + export function extractApiErrorMessage(errorBody: unknown): string | null { if (typeof errorBody === "string") { - return errorBody.trim() || null; + return trimmedMessage(errorBody); } if (errorBody instanceof Error) { - return errorBody.message?.trim() || null; + return trimmedMessage(errorBody.message); } - if (!errorBody || typeof errorBody !== "object") { + + const anyBody = asErrorRecord(errorBody); + if (!anyBody) { return null; } - const anyBody = errorBody as Record; - - const direct = - (typeof anyBody.error === "string" && anyBody.error) || - (typeof anyBody.message === "string" && anyBody.message) || - (typeof anyBody.detail === "string" && anyBody.detail); - if (direct) { - return direct.trim() || null; + const direct = firstStringMessage(anyBody); + if (direct !== undefined) { + return direct; } - const nestedError = anyBody.error; - if (nestedError && typeof nestedError === "object") { - const nested = nestedError as Record; - const nestedDirect = - (typeof nested.error === "string" && nested.error) || - (typeof nested.message === "string" && nested.message) || - (typeof nested.detail === "string" && nested.detail); - if (nestedDirect) { - return nestedDirect.trim() || null; - } + const nestedDirect = firstRecordMessage(anyBody.error); + if (nestedDirect !== undefined) { + return nestedDirect; } const errors = anyBody.errors; if (Array.isArray(errors) && errors.length > 0) { - const first = errors[0]; - if (first && typeof first === "object") { - const firstObj = first as Record; - const firstMessage = - (typeof firstObj.message === "string" && firstObj.message) || - (typeof firstObj.error === "string" && firstObj.error); - if (firstMessage) { - return firstMessage.trim() || null; - } + const firstMessage = firstRecordMessage(errors[0], ["message", "error"]); + if (firstMessage !== undefined) { + return firstMessage; } } diff --git a/src/lib/api/helpers.ts b/src/lib/api/helpers.ts index 3af0a443..15f7c898 100644 --- a/src/lib/api/helpers.ts +++ b/src/lib/api/helpers.ts @@ -1,6 +1,7 @@ import type { NextRequest } from "next/server"; import { NextResponse } from "next/server"; import type { z } from "zod"; +import { resourceIdPathParamsSchema } from "@/lib/api/schemas/request-schemas"; import { logRouteFailure } from "@/lib/log/app-logger"; import { serializeDatesDeep } from "@/lib/time/serialize-date-output"; import { parseInteger, parseIntegerList } from "./request-integers"; @@ -15,7 +16,7 @@ export type PaginatedResponse = { }; }; -export type PaginationInput = { +type PaginationInput = { page?: number | string | null; pageSize?: number | string | null; defaultPage?: number; @@ -23,7 +24,7 @@ export type PaginationInput = { maxPageSize?: number; }; -type GetPaginationOptions = Pick< +type PaginationOptions = Pick< PaginationInput, "defaultPage" | "defaultPageSize" | "maxPageSize" > & { @@ -35,6 +36,10 @@ type ParseRouteOptions = { logErrors?: boolean; }; +type ParseRouteQueryOptions = ParseRouteOptions & { + pagination?: PaginationOptions; +}; + /** * Normalize pagination values from query params or parsed route input. */ @@ -51,22 +56,6 @@ export function normalizePagination(input: PaginationInput = {}) { }; } -/** - * Parse pagination from URL search params using the shared normalizer. - */ -export function getPagination( - searchParams: URLSearchParams, - options: GetPaginationOptions = {}, -) { - return normalizePagination({ - page: searchParams.get(options.pageParam ?? "page"), - pageSize: searchParams.get(options.pageSizeParam ?? "limit"), - defaultPage: options.defaultPage, - defaultPageSize: options.defaultPageSize, - maxPageSize: options.maxPageSize ?? 100, - }); -} - /** * Build a standard `{ data, pagination }` API response envelope. */ @@ -169,6 +158,60 @@ export function parseRouteInput( return parsed.data; } +function searchParamsInput( + searchParams: URLSearchParams, + schema: TSchema, +) { + return Object.fromEntries( + Object.keys(schema.shape).map((key) => [ + key, + searchParams.get(key) ?? undefined, + ]), + ); +} + +export function parseRouteSearchParams( + searchParams: URLSearchParams, + schema: TSchema, + message: string, + options?: ParseRouteOptions, +): z.output | Response { + return parseRouteInput( + searchParamsInput(searchParams, schema), + schema, + message, + options, + ); +} + +export function parseRouteQuery( + searchParams: URLSearchParams, + schema: TSchema, + message: string, + options?: ParseRouteQueryOptions, +): + | { + query: z.output; + pagination: ReturnType; + } + | Response { + const query = parseRouteSearchParams(searchParams, schema, message, options); + if (query instanceof Response) { + return query; + } + + return { + query, + pagination: normalizePagination({ + page: searchParams.get(options?.pagination?.pageParam ?? "page"), + pageSize: searchParams.get(options?.pagination?.pageSizeParam ?? "limit"), + defaultPage: options?.pagination?.defaultPage, + defaultPageSize: options?.pagination?.defaultPageSize, + maxPageSize: options?.pagination?.maxPageSize, + }), + }; +} + export async function parseRouteParams( params: Promise, schema: TSchema, @@ -201,4 +244,24 @@ export async function parseRouteJsonBody( }); } +/** + * Parse a `[id]` path param using the canonical resourceIdPathParamsSchema. + * Replaces duplicated parseCommentId / parseHomeworkId / parseTodoId / parseUploadId helpers. + */ +export async function parseResourceIdParam( + params: Promise<{ id: string }>, + label: string, +): Promise { + const parsed = await parseRouteParams( + params, + resourceIdPathParamsSchema, + `Invalid ${label} ID`, + ); + if (parsed instanceof Response) { + return parsed; + } + + return parsed.id; +} + export { parseInteger, parseIntegerList }; diff --git a/src/lib/api/schemas.ts b/src/lib/api/schemas.ts deleted file mode 100644 index 4c68c094..00000000 --- a/src/lib/api/schemas.ts +++ /dev/null @@ -1,2 +0,0 @@ -export * from "./schemas/request-schemas"; -export * from "./schemas/response-schemas"; diff --git a/src/lib/course-section-queries.ts b/src/lib/course-section-queries.ts index b1ff2497..2c48470a 100644 --- a/src/lib/course-section-queries.ts +++ b/src/lib/course-section-queries.ts @@ -4,7 +4,11 @@ import { } from "@/lib/course-section-query-filters"; import { findCurrentSemester } from "@/lib/current-semester"; import { getPrisma, prisma } from "@/lib/db/prisma"; -import { findClosestMatches } from "@/lib/fuzzy-match"; +import { + extractCodePrefixes, + findClosestMatches, + normalizeFuzzyValue, +} from "@/lib/fuzzy-match"; import { courseDetailInclude, courseInclude, @@ -124,33 +128,10 @@ export async function findSectionCodeMatches( }); const matchedCodes = sections.map((section) => section.code); - const unmatchedCodes = codes.filter((code) => !matchedCodes.includes(code)); - const normalizedChunks = (value: string) => - value - .trim() - .toUpperCase() - .replace(/\s+/g, "") - .match(/[A-Z0-9]+/g) ?? []; + const matchedCodeSet = new Set(matchedCodes); + const unmatchedCodes = codes.filter((code) => !matchedCodeSet.has(code)); const unmatchedCodePrefixes = new Map( - unmatchedCodes.map((code) => { - const chunks = normalizedChunks(code); - const significantChunks = chunks.filter( - (chunk) => chunk.length >= 4 || (chunk.length >= 3 && /\d/.test(chunk)), - ); - const lookupChunks = - significantChunks.length > 0 ? significantChunks : chunks; - const prefixes = Array.from( - new Set( - lookupChunks - .map((chunk) => - chunk.slice(0, Math.min(6, Math.max(3, chunk.length))), - ) - .filter((chunk) => chunk.length >= 3), - ), - ).slice(0, 6); - - return [code, prefixes] as const; - }), + unmatchedCodes.map((code) => [code, extractCodePrefixes(code)] as const), ); const batchedPrefixes = Array.from( new Set(Array.from(unmatchedCodePrefixes.values()).flat()), @@ -174,7 +155,7 @@ export async function findSectionCodeMatches( const normalizedSemesterCodes: Array<{ code: string; normalized: string }> = semesterCodes.map((code: string) => ({ code, - normalized: code.trim().toUpperCase().replace(/\s+/g, ""), + normalized: normalizeFuzzyValue(code), })); const prefixMatches = new Map( batchedPrefixes.map((prefix) => [ diff --git a/src/lib/course-section-query-filters.ts b/src/lib/course-section-query-filters.ts index d80b11fd..cdf8648b 100644 --- a/src/lib/course-section-query-filters.ts +++ b/src/lib/course-section-query-filters.ts @@ -1,8 +1,13 @@ -import { parseInteger, parseIntegerList } from "@/lib/api/helpers"; +import type { Prisma } from "@/generated/prisma/client"; +import { + applyIntegerFilter, + buildJwIdFilter, + buildRelatedFilter, + type IntegerFilter, + parseIdsFilter, +} from "@/lib/query-filter-helpers"; import { buildSectionSearchWhere, ilike } from "@/lib/query-helpers"; -type IntegerFilter = number | string | null | undefined; - type CourseListFilters = { search?: string | null; educationLevelId?: IntegerFilter; @@ -24,23 +29,14 @@ type SectionListFilters = { search?: string | null; }; -const parseIntegerFilter = (value: IntegerFilter) => parseInteger(value); +// IntegerFilter, parseIdsFilter, applyIntegerFilter, buildRelatedFilter, buildJwIdFilter +// are now exported from @/lib/query-filter-helpers.ts -const parseIdsFilter = (value: number[] | string | null | undefined) => { - if (Array.isArray(value)) { - return value.filter(Number.isInteger); - } - return parseIntegerList(value); -}; - -export function buildCourseListWhere(filters: CourseListFilters) { +export function buildCourseListWhere( + filters: CourseListFilters, +): Prisma.CourseWhereInput | undefined { const { search, educationLevelId, categoryId, classTypeId } = filters; - const where: { - OR?: Array>; - educationLevelId?: number; - categoryId?: number; - classTypeId?: number; - } = {}; + const where: Prisma.CourseWhereInput = {}; if (search) { where.OR = [ @@ -50,26 +46,15 @@ export function buildCourseListWhere(filters: CourseListFilters) { ]; } - const parsedEducationLevelId = parseIntegerFilter(educationLevelId); - if (parsedEducationLevelId !== null) { - where.educationLevelId = parsedEducationLevelId; - } - - const parsedCategoryId = parseIntegerFilter(categoryId); - if (parsedCategoryId !== null) { - where.categoryId = parsedCategoryId; - } - - const parsedClassTypeId = parseIntegerFilter(classTypeId); - if (parsedClassTypeId !== null) { - where.classTypeId = parsedClassTypeId; - } + applyIntegerFilter(where, "educationLevelId", educationLevelId); + applyIntegerFilter(where, "categoryId", categoryId); + applyIntegerFilter(where, "classTypeId", classTypeId); return Object.keys(where).length > 0 ? where : undefined; } export function buildSectionListQuery(filters: SectionListFilters): { - where: Record; + where: Prisma.SectionWhereInput; orderBy?: ReturnType["orderBy"]; } { const { @@ -85,54 +70,25 @@ export function buildSectionListQuery(filters: SectionListFilters): { jwIds, search, } = filters; - const where: Record = {}; + const where: Prisma.SectionWhereInput = {}; - const parsedCourseId = parseIntegerFilter(courseId); - if (parsedCourseId !== null) { - where.courseId = parsedCourseId; - } - const courseFilter: { jwId?: number } = {}; - const parsedCourseJwId = parseIntegerFilter(courseJwId); - if (parsedCourseJwId !== null) { - courseFilter.jwId = parsedCourseJwId; - } - if (Object.keys(courseFilter).length > 0) { + applyIntegerFilter(where, "courseId", courseId); + const courseFilter = buildJwIdFilter(courseJwId); + if (courseFilter) { where.course = courseFilter; } - const parsedSemesterId = parseIntegerFilter(semesterId); - if (parsedSemesterId !== null) { - where.semesterId = parsedSemesterId; - } - const semesterFilter: { jwId?: number } = {}; - const parsedSemesterJwId = parseIntegerFilter(semesterJwId); - if (parsedSemesterJwId !== null) { - semesterFilter.jwId = parsedSemesterJwId; - } - if (Object.keys(semesterFilter).length > 0) { + applyIntegerFilter(where, "semesterId", semesterId); + const semesterFilter = buildJwIdFilter(semesterJwId); + if (semesterFilter) { where.semester = semesterFilter; } - const parsedCampusId = parseIntegerFilter(campusId); - if (parsedCampusId !== null) { - where.campusId = parsedCampusId; - } + applyIntegerFilter(where, "campusId", campusId); + applyIntegerFilter(where, "openDepartmentId", departmentId); - const parsedDepartmentId = parseIntegerFilter(departmentId); - if (parsedDepartmentId !== null) { - where.openDepartmentId = parsedDepartmentId; - } - - const teacherFilter: { id?: number; code?: string } = {}; - const parsedTeacherId = parseIntegerFilter(teacherId); - if (parsedTeacherId !== null) { - teacherFilter.id = parsedTeacherId; - } - const trimmedTeacherCode = teacherCode?.trim(); - if (trimmedTeacherCode) { - teacherFilter.code = trimmedTeacherCode; - } - if (Object.keys(teacherFilter).length > 0) { + const teacherFilter = buildRelatedFilter("id", teacherId, teacherCode); + if (teacherFilter) { where.teachers = { some: teacherFilter, }; diff --git a/src/lib/current-semester.ts b/src/lib/current-semester.ts index 9e01fa60..ec8202c2 100644 --- a/src/lib/current-semester.ts +++ b/src/lib/current-semester.ts @@ -7,28 +7,19 @@ type SemesterWithDateRange = { type SemesterFindFirstDelegate = Pick; -const getSemesterStartTime = (semester: SemesterWithDateRange) => - semester.startDate?.getTime() ?? Number.NEGATIVE_INFINITY; +const startTime = (s: SemesterWithDateRange) => + s.startDate?.getTime() ?? Number.NEGATIVE_INFINITY; -const getSemesterEndTime = (semester: SemesterWithDateRange) => - semester.endDate?.getTime() ?? Number.POSITIVE_INFINITY; +const endTime = (s: SemesterWithDateRange) => + s.endDate?.getTime() ?? Number.POSITIVE_INFINITY; -const hasStarted = (semester: SemesterWithDateRange, referenceDate: Date) => - !semester.startDate || semester.startDate <= referenceDate; - -const hasNotEnded = (semester: SemesterWithDateRange, referenceDate: Date) => - !semester.endDate || semester.endDate >= referenceDate; - -const compareMostSpecificCurrentSemester = < - TSemester extends SemesterWithDateRange, ->( +/** + * Sort comparator: prefer later start, then earlier end (most specific current semester). + */ +const byMostSpecific = ( a: TSemester, b: TSemester, -) => { - const startDiff = getSemesterStartTime(b) - getSemesterStartTime(a); - if (startDiff !== 0) return startDiff; - return getSemesterEndTime(a) - getSemesterEndTime(b); -}; +) => startTime(b) - startTime(a) || endTime(a) - endTime(b); export const buildCurrentSemesterWhere = ( referenceDate: Date, @@ -57,23 +48,22 @@ export const selectCurrentSemesterFromList = < semesters: TSemester[], referenceDate: Date, ): TSemester | null => { + // 1. Prefer a semester currently in session (started and not yet ended) const current = semesters .filter( - (semester) => - hasStarted(semester, referenceDate) && - hasNotEnded(semester, referenceDate), + (s) => + (!s.startDate || s.startDate <= referenceDate) && + (!s.endDate || s.endDate >= referenceDate), ) - .sort(compareMostSpecificCurrentSemester); + .sort(byMostSpecific); if (current[0]) return current[0]; + // 2. Fall back to the nearest upcoming semester const future = semesters - .filter((semester) => !hasStarted(semester, referenceDate)) - .sort((a, b) => { - const startDiff = getSemesterStartTime(a) - getSemesterStartTime(b); - if (startDiff !== 0) return startDiff; - return getSemesterEndTime(a) - getSemesterEndTime(b); - }); + .filter((s) => s.startDate && s.startDate > referenceDate) + .sort((a, b) => startTime(a) - startTime(b) || endTime(a) - endTime(b)); if (future[0]) return future[0]; - return [...semesters].sort(compareMostSpecificCurrentSemester).at(0) ?? null; + // 3. Fall back to the most specific semester overall (likely the most recent past) + return [...semesters].sort(byMostSpecific).at(0) ?? null; }; diff --git a/src/lib/db/prisma.ts b/src/lib/db/prisma.ts index 879e4332..1f5672f0 100644 --- a/src/lib/db/prisma.ts +++ b/src/lib/db/prisma.ts @@ -1,6 +1,11 @@ import { format as formatLogArgs } from "node:util"; import { Prisma, PrismaClient } from "@/generated/prisma/client"; import { createPrismaAdapter } from "@/lib/db/prisma-adapter"; +import { + getPrismaQueryDebugMode, + getPrismaSlowQueryThresholdMs, + shouldEnablePrismaQueryLogging, +} from "@/lib/db/prisma-query-logging"; import { shouldLog } from "@/lib/log/app-logger"; import { formatShanghaiTimestamp } from "@/lib/time/shanghai-format"; @@ -11,35 +16,6 @@ const globalForPrisma = globalThis as unknown as { const QUERY_LOG_TEXT_LIMIT = 2_000; -function getPrismaDebugValue() { - return process.env.PRISMA_QUERY_DEBUG?.trim().toLowerCase(); -} - -function isPrismaQueryDebugEnabled() { - const value = getPrismaDebugValue(); - return value === "1" || value === "true" || value === "yes"; -} - -function isPrismaQueryVerbose() { - return getPrismaDebugValue() === "verbose"; -} - -function getPrismaSlowQueryThresholdMs() { - const raw = process.env.PRISMA_SLOW_QUERY_MS?.trim(); - if (!raw) return null; - - const parsed = Number.parseInt(raw, 10); - return Number.isFinite(parsed) && parsed >= 0 ? parsed : null; -} - -function shouldEnablePrismaQueryLogging() { - return ( - isPrismaQueryDebugEnabled() || - isPrismaQueryVerbose() || - getPrismaSlowQueryThresholdMs() != null - ); -} - function compactQueryText(value: string) { const compact = value.replace(/\s+/g, " ").trim(); if (compact.length <= QUERY_LOG_TEXT_LIMIT) return compact; @@ -73,9 +49,10 @@ function logPrismaQueryEvent( function logPrismaQuery(event: Prisma.QueryEvent) { const slowThresholdMs = getPrismaSlowQueryThresholdMs(); + const debugMode = getPrismaQueryDebugMode(); const isSlow = slowThresholdMs != null && event.duration >= slowThresholdMs; - if (!isSlow && !isPrismaQueryDebugEnabled() && !isPrismaQueryVerbose()) { + if (!isSlow && debugMode === "off") { return; } @@ -85,7 +62,7 @@ function logPrismaQuery(event: Prisma.QueryEvent) { durationMs: event.duration, target: event.target, query: compactQueryText(event.query), - ...(isPrismaQueryVerbose() + ...(debugMode === "verbose" ? { params: compactQueryText(event.params) } : {}), }); diff --git a/src/lib/navigation/search-params.ts b/src/lib/navigation/search-params.ts index 3a46cd22..640fce67 100644 --- a/src/lib/navigation/search-params.ts +++ b/src/lib/navigation/search-params.ts @@ -1,31 +1,16 @@ type SearchParamValue = string | null | undefined; -function appendIfPresent( - params: URLSearchParams, - key: string, - value: SearchParamValue, -) { - if (value === undefined || value === null || value === "") return; - params.set(key, value); -} - export function buildSearchParams({ values, - preserve, }: { values: Record; - preserve?: Record; }) { const params = new URLSearchParams(); - if (preserve) { - for (const [key, value] of Object.entries(preserve)) { - appendIfPresent(params, key, value); - } - } - for (const [key, value] of Object.entries(values)) { - appendIfPresent(params, key, value); + if (value !== undefined && value !== null && value !== "") { + params.set(key, value); + } } return params.toString(); diff --git a/src/lib/query-filter-helpers.ts b/src/lib/query-filter-helpers.ts new file mode 100644 index 00000000..0ba1f1d2 --- /dev/null +++ b/src/lib/query-filter-helpers.ts @@ -0,0 +1,55 @@ +import { parseInteger, parseIntegerList } from "@/lib/api/helpers"; + +export type IntegerFilter = number | string | null | undefined; + +export const parseIdsFilter = (value: number[] | string | null | undefined) => { + if (Array.isArray(value)) { + return value.filter(Number.isInteger); + } + return parseIntegerList(value); +}; + +/** + * Build a `{ ?, code? }` filter object for a related entity. + * Returns `undefined` if neither id nor code resolves to a value. + */ +export function buildRelatedFilter( + key: "id" | "jwId", + idValue: IntegerFilter, + code?: string | null, +): Record | undefined { + const filter: Record = {}; + const parsedId = parseInteger(idValue); + if (parsedId !== null) { + filter[key] = parsedId; + } + const trimmedCode = code?.trim(); + if (trimmedCode) { + filter.code = trimmedCode; + } + return Object.keys(filter).length > 0 ? filter : undefined; +} + +/** + * Assigns a parsed integer value to a dynamic key on a Prisma where input. + * Uses `as any` for the dynamic key because TypeScript cannot verify + * string-keyed assignments on strongly-typed Prisma where inputs at compile time. + */ +export function applyIntegerFilter>( + where: T, + key: string, + value: IntegerFilter, +) { + const parsed = parseInteger(value); + if (parsed !== null) { + Object.assign(where, { [key]: parsed }); + } +} + +/** + * Build a `{ jwId }` filter for a relation, or `undefined` if the value is not a valid integer. + */ +export function buildJwIdFilter(value: IntegerFilter) { + const jwId = parseInteger(value); + return jwId === null ? undefined : { jwId }; +} diff --git a/src/lib/query-helpers.ts b/src/lib/query-helpers.ts index 4b4249b9..f3df2811 100644 --- a/src/lib/query-helpers.ts +++ b/src/lib/query-helpers.ts @@ -121,43 +121,94 @@ type ParsedSectionSearchQuery = { general?: string; }; +type SectionSearchStringKey = Exclude< + keyof ParsedSectionSearchQuery, + "general" | "order" +>; +type SectionSearchConditionKey = Exclude< + keyof ParsedSectionSearchQuery, + "general" | "sort" | "order" +>; + +const SECTION_SEARCH_FIELDS: Array<{ + key: SectionSearchStringKey; + pattern: RegExp; +}> = [ + { key: "teacher", pattern: /teacher:(\S+)/i }, + { key: "courseCode", pattern: /coursecode:(\S+)/i }, + { key: "lectureCode", pattern: /(?:lecturecode|sectioncode):(\S+)/i }, + { key: "campus", pattern: /campus:(\S+)/i }, + { key: "credits", pattern: /credits?:(\S+)/i }, + { key: "department", pattern: /(?:department|dept):(\S+)/i }, + { key: "semester", pattern: /semester:(\S+)/i }, + { key: "category", pattern: /category:(\S+)/i }, + { key: "level", pattern: /(?:level|edulevel):(\S+)/i }, + { key: "classType", pattern: /(?:classtype|type):(\S+)/i }, + { key: "sort", pattern: /(?:sort|sortby):(\S+)/i }, +]; + +const SECTION_SEARCH_TAG_PATTERN = + /\b(?:teacher|coursecode|lecturecode|sectioncode|campus|credits?|department|dept|semester|category|level|edulevel|classtype|type|sort|sortby|order):\S+/gi; + +const SECTION_SEARCH_CONDITIONS: Array<{ + key: SectionSearchConditionKey; + build: (value: string) => Prisma.SectionWhereInput | undefined; +}> = [ + { + key: "teacher", + build: (value) => ({ teachers: { some: { nameCn: ilike(value) } } }), + }, + { key: "courseCode", build: (value) => ({ course: { code: ilike(value) } }) }, + { key: "lectureCode", build: (value) => ({ code: ilike(value) }) }, + { key: "campus", build: (value) => ({ campus: { nameCn: ilike(value) } }) }, + { + key: "credits", + build: (value) => { + const credits = Number(value); + return Number.isFinite(credits) ? { credits } : undefined; + }, + }, + { + key: "department", + build: (value) => ({ openDepartment: { nameCn: ilike(value) } }), + }, + { + key: "semester", + build: (value) => ({ semester: { nameCn: ilike(value) } }), + }, + { + key: "category", + build: (value) => ({ course: { category: { nameCn: ilike(value) } } }), + }, + { + key: "level", + build: (value) => ({ + course: { educationLevel: { nameCn: ilike(value) } }, + }), + }, + { + key: "classType", + build: (value) => ({ course: { classType: { nameCn: ilike(value) } } }), + }, +]; + export function parseSectionSearchQuery( search: string, ): ParsedSectionSearchQuery { const result: ParsedSectionSearchQuery = {}; - const teacherMatch = search.match(/teacher:(\S+)/i); - const courseCodeMatch = search.match(/coursecode:(\S+)/i); - const lectureCodeMatch = search.match(/(?:lecturecode|sectioncode):(\S+)/i); - const campusMatch = search.match(/campus:(\S+)/i); - const creditsMatch = search.match(/credits?:(\S+)/i); - const departmentMatch = search.match(/(?:department|dept):(\S+)/i); - const semesterMatch = search.match(/semester:(\S+)/i); - const categoryMatch = search.match(/category:(\S+)/i); - const levelMatch = search.match(/(?:level|edulevel):(\S+)/i); - const classTypeMatch = search.match(/(?:classtype|type):(\S+)/i); - const sortMatch = search.match(/(?:sort|sortby):(\S+)/i); + for (const field of SECTION_SEARCH_FIELDS) { + const match = search.match(field.pattern); + if (match) { + result[field.key] = match[1]; + } + } + const orderMatch = search.match(/order:(asc|desc)/i); - if (teacherMatch) result.teacher = teacherMatch[1]; - if (courseCodeMatch) result.courseCode = courseCodeMatch[1]; - if (lectureCodeMatch) result.lectureCode = lectureCodeMatch[1]; - if (campusMatch) result.campus = campusMatch[1]; - if (creditsMatch) result.credits = creditsMatch[1]; - if (departmentMatch) result.department = departmentMatch[1]; - if (semesterMatch) result.semester = semesterMatch[1]; - if (categoryMatch) result.category = categoryMatch[1]; - if (levelMatch) result.level = levelMatch[1]; - if (classTypeMatch) result.classType = classTypeMatch[1]; - if (sortMatch) result.sort = sortMatch[1]; if (orderMatch) result.order = orderMatch[1].toLowerCase() as "asc" | "desc"; - const generalSearch = search - .replace( - /\b(?:teacher|coursecode|lecturecode|sectioncode|campus|credits?|department|dept|semester|category|level|edulevel|classtype|type|sort|sortby|order):\S+/gi, - "", - ) - .trim(); + const generalSearch = search.replace(SECTION_SEARCH_TAG_PATTERN, "").trim(); if (generalSearch) result.general = generalSearch; @@ -196,94 +247,11 @@ export function buildSectionSearchWhere(search?: string): { const parsed = parseSectionSearchQuery(search); const orderBy = buildSectionOrderBy(parsed.sort, parsed.order || "asc"); - const conditions: Prisma.SectionWhereInput[] = []; - - if (parsed.teacher) { - conditions.push({ - teachers: { - some: { - nameCn: ilike(parsed.teacher), - }, - }, - }); - } - - if (parsed.courseCode) { - conditions.push({ - course: { - code: ilike(parsed.courseCode), - }, - }); - } - - if (parsed.lectureCode) { - conditions.push({ - code: ilike(parsed.lectureCode), - }); - } - - if (parsed.campus) { - conditions.push({ - campus: { - nameCn: ilike(parsed.campus), - }, - }); - } - - if (parsed.credits) { - const creditsNum = parseFloat(parsed.credits); - if (!Number.isNaN(creditsNum)) { - conditions.push({ - credits: creditsNum, - }); - } - } - - if (parsed.department) { - conditions.push({ - openDepartment: { - nameCn: ilike(parsed.department), - }, - }); - } - - if (parsed.semester) { - conditions.push({ - semester: { - nameCn: ilike(parsed.semester), - }, - }); - } - - if (parsed.category) { - conditions.push({ - course: { - category: { - nameCn: ilike(parsed.category), - }, - }, - }); - } - - if (parsed.level) { - conditions.push({ - course: { - educationLevel: { - nameCn: ilike(parsed.level), - }, - }, - }); - } - - if (parsed.classType) { - conditions.push({ - course: { - classType: { - nameCn: ilike(parsed.classType), - }, - }, - }); - } + const conditions = SECTION_SEARCH_CONDITIONS.flatMap((field) => { + const value = parsed[field.key]; + const condition = value ? field.build(value) : undefined; + return condition ? [condition] : []; + }); if (parsed.general) { conditions.push({ @@ -318,6 +286,7 @@ export function buildSectionSearchWhere(search?: string): { export function paginatedSectionQuery( page: number, + pageSize?: number, where?: Prisma.SectionWhereInput, orderBy?: | Prisma.SectionOrderByWithRelationInput @@ -336,11 +305,13 @@ export function paginatedSectionQuery( }), () => prisma.section.count({ where }), page, + pageSize, ); } export function paginatedCourseQuery( page: number, + pageSize?: number, where?: Prisma.CourseWhereInput, orderBy?: | Prisma.CourseOrderByWithRelationInput @@ -359,6 +330,7 @@ export function paginatedCourseQuery( }), () => prisma.course.count({ where }), page, + pageSize, ); } @@ -402,6 +374,7 @@ export const teacherDetailInclude = { export function paginatedTeacherQuery( page: number, + pageSize?: number, where?: Prisma.TeacherWhereInput, orderBy?: | Prisma.TeacherOrderByWithRelationInput @@ -420,5 +393,6 @@ export function paginatedTeacherQuery( }), () => prisma.teacher.count({ where }), page, + pageSize, ); } diff --git a/src/lib/schedule-queries.ts b/src/lib/schedule-queries.ts index a6faaf5b..39c15288 100644 --- a/src/lib/schedule-queries.ts +++ b/src/lib/schedule-queries.ts @@ -1,4 +1,10 @@ -type IntegerFilter = number | string | null | undefined; +import type { Prisma } from "@/generated/prisma/client"; +import { + applyIntegerFilter, + buildJwIdFilter, + buildRelatedFilter, + type IntegerFilter, +} from "@/lib/query-filter-helpers"; type ScheduleListFilters = { sectionId?: IntegerFilter; @@ -13,24 +19,6 @@ type ScheduleListFilters = { dateTo?: Date; }; -function parseIntegerFilter(value: IntegerFilter) { - if (typeof value === "number") { - return Number.isInteger(value) ? value : null; - } - - if (typeof value !== "string") { - return null; - } - - const trimmed = value.trim(); - if (!trimmed) { - return null; - } - - const parsed = Number.parseInt(trimmed, 10); - return Number.isNaN(parsed) ? null : parsed; -} - export const publicScheduleInclude = { room: { include: { @@ -70,78 +58,35 @@ export function buildScheduleListWhere(filters: ScheduleListFilters) { dateTo, } = filters; - const where: { - sectionId?: number; - section?: { jwId?: number; code?: string }; - teachers?: { some: { id?: number; code?: string } }; - roomId?: number; - room?: { jwId?: number }; - date?: { gte?: Date; lte?: Date }; - weekday?: number; - } = {}; + const where: Prisma.ScheduleWhereInput = {}; - const parsedSectionId = parseIntegerFilter(sectionId); - if (parsedSectionId !== null) { - where.sectionId = parsedSectionId; - } + applyIntegerFilter(where, "sectionId", sectionId); - const sectionFilter: { jwId?: number; code?: string } = {}; - const parsedSectionJwId = parseIntegerFilter(sectionJwId); - if (parsedSectionJwId !== null) { - sectionFilter.jwId = parsedSectionJwId; - } - const trimmedSectionCode = sectionCode?.trim(); - if (trimmedSectionCode) { - sectionFilter.code = trimmedSectionCode; - } - if (Object.keys(sectionFilter).length > 0) { + const sectionFilter = buildRelatedFilter("jwId", sectionJwId, sectionCode); + if (sectionFilter) { where.section = sectionFilter; } - const teacherFilter: { id?: number; code?: string } = {}; - const parsedTeacherId = parseIntegerFilter(teacherId); - if (parsedTeacherId !== null) { - teacherFilter.id = parsedTeacherId; - } - const trimmedTeacherCode = teacherCode?.trim(); - if (trimmedTeacherCode) { - teacherFilter.code = trimmedTeacherCode; - } - if (Object.keys(teacherFilter).length > 0) { - where.teachers = { - some: teacherFilter, - }; + const teacherFilter = buildRelatedFilter("id", teacherId, teacherCode); + if (teacherFilter) { + where.teachers = { some: teacherFilter }; } - const parsedRoomId = parseIntegerFilter(roomId); - if (parsedRoomId !== null) { - where.roomId = parsedRoomId; - } + applyIntegerFilter(where, "roomId", roomId); - const roomFilter: { jwId?: number } = {}; - const parsedRoomJwId = parseIntegerFilter(roomJwId); - if (parsedRoomJwId !== null) { - roomFilter.jwId = parsedRoomJwId; - } - if (Object.keys(roomFilter).length > 0) { + const roomFilter = buildJwIdFilter(roomJwId); + if (roomFilter) { where.room = roomFilter; } - const dateFilter: { gte?: Date; lte?: Date } = {}; - if (dateFrom) { - dateFilter.gte = dateFrom; - } - if (dateTo) { - dateFilter.lte = dateTo; - } if (dateFrom || dateTo) { - where.date = dateFilter; + where.date = { + ...(dateFrom && { gte: dateFrom }), + ...(dateTo && { lte: dateTo }), + }; } - const parsedWeekday = parseIntegerFilter(weekday); - if (parsedWeekday !== null) { - where.weekday = parsedWeekday; - } + applyIntegerFilter(where, "weekday", weekday); return where; } diff --git a/src/lib/time/serialize-date-output.ts b/src/lib/time/serialize-date-output.ts index a45c4650..1308cb90 100644 --- a/src/lib/time/serialize-date-output.ts +++ b/src/lib/time/serialize-date-output.ts @@ -1,8 +1,5 @@ import { formatShanghaiTimestamp } from "@/lib/time/shanghai-format"; - -function isRecord(value: unknown): value is Record { - return typeof value === "object" && value !== null; -} +import { isRecord } from "@/lib/utils"; export function toShanghaiIsoString(date: Date): string { return formatShanghaiTimestamp(date); diff --git a/src/lib/time/shanghai-format.ts b/src/lib/time/shanghai-format.ts index 62776419..f50f54c9 100644 --- a/src/lib/time/shanghai-format.ts +++ b/src/lib/time/shanghai-format.ts @@ -23,9 +23,8 @@ export function toShanghaiDateTimeLocalValue( value: string | Date | null | undefined, ): string { if (!value) return ""; - const parsed = - value instanceof Date ? value : (parseDateInput(value) ?? undefined); - if (!(parsed instanceof Date)) return ""; + const parsed = value instanceof Date ? value : parseDateInput(value); + if (!parsed) return ""; return shanghaiDayjs(parsed).format(APP_DATETIME_LOCAL_FORMAT); } diff --git a/src/shared/lib/time-utils.ts b/src/shared/lib/time-utils.ts index 0f6c1aea..b3404881 100644 --- a/src/shared/lib/time-utils.ts +++ b/src/shared/lib/time-utils.ts @@ -1,3 +1,4 @@ +import type { Dayjs } from "dayjs"; import { APP_TIME_ZONE } from "@/lib/time/parse-date-input"; import { shanghaiDayjs } from "@/lib/time/shanghai-dayjs"; import { isSameDefaultWeek } from "@/shared/lib/date-utils"; @@ -8,10 +9,14 @@ function isZhLocale(locale: string): boolean { } function intlLocale(locale: string): string { - const l = locale.replace(/_/g, "-"); - const low = l.toLowerCase(); - if (low === "zh-cn" || low.startsWith("zh-") || low === "zh") return "zh-CN"; - return l.length >= 2 ? l : "en-US"; + const normalized = locale.replace(/_/g, "-"); + if (isZhLocale(normalized)) return "zh-CN"; + return normalized.length >= 2 ? normalized : "en-US"; +} + +function formatZhMonthDay(due: Dayjs, includeYear: boolean): string { + const monthDay = `${due.month() + 1}月${due.date()}日`; + return includeYear ? `${due.year()}年${monthDay}` : monthDay; } /** @@ -48,10 +53,7 @@ export function formatSmartDateTime( const sameYear = due.year() === ref.year(); if (isZh) { - if (sameYear) { - return `${due.month() + 1}月${due.date()}日 ${time}`; - } - return `${due.year()}年${due.month() + 1}月${due.date()}日 ${time}`; + return `${formatZhMonthDay(due, !sameYear)} ${time}`; } const d = due.toDate(); @@ -74,6 +76,41 @@ export function formatSmartDateTime( }).format(d); } +/** + * Short distance label for deadline surfaces: 今天 / 2周后 / 已逾期. + */ +export function formatDueRelativeTime( + input: Date | string | number, + referenceInput: Date | string | number, + locale: string, +): string { + const due = shanghaiDayjs(input); + const ref = shanghaiDayjs(referenceInput); + const isZh = isZhLocale(locale); + if (!due.isValid() || !ref.isValid()) return ""; + + if (due.diff(ref, "minute", true) <= 0) { + return isZh ? "已逾期" : "Overdue"; + } + + const dayDiff = due.startOf("day").diff(ref.startOf("day"), "day"); + if (dayDiff === 0) { + return isZh ? "今天" : "Today"; + } + + const absDayDiff = Math.abs(dayDiff); + const formatter = new Intl.RelativeTimeFormat(intlLocale(locale), { + numeric: "auto", + }); + if (absDayDiff >= 60) { + return formatter.format(Math.round(dayDiff / 30), "month"); + } + if (absDayDiff >= 14) { + return formatter.format(Math.round(dayDiff / 7), "week"); + } + return formatter.format(dayDiff, "day"); +} + /** * Date-only smart label (e.g. todo due date): no time fragment. */ @@ -106,10 +143,7 @@ export function formatSmartDate( const sameYear = due.year() === ref.year(); if (isZh) { - if (sameYear) { - return `${due.month() + 1}月${due.date()}日`; - } - return `${due.year()}年${due.month() + 1}月${due.date()}日`; + return formatZhMonthDay(due, !sameYear); } const d = due.toDate(); diff --git a/tests/e2e/src/app/api/bus/test.ts b/tests/e2e/src/app/api/bus/test.ts index 00fabeee..8c7f55f4 100644 --- a/tests/e2e/src/app/api/bus/test.ts +++ b/tests/e2e/src/app/api/bus/test.ts @@ -63,15 +63,18 @@ test.describe("GET /api/bus", () => { (version) => version.key === DEV_SEED.bus.versionKey, ), ).toBe(true); - expect(body.routes?.map((route) => route.id).sort()).toEqual([1, 3, 7, 8]); + const routeIds = body.routes?.map((route) => route.id).sort() ?? []; + expect(routeIds).toEqual( + expect.arrayContaining([1, 2, 3, 4, 5, 6, 7, 8, 11, 12]), + ); const weekdayTrips = body.trips?.filter((trip) => trip.dayType === "weekday").length ?? 0; const weekendTrips = body.trips?.filter((trip) => trip.dayType === "weekend").length ?? 0; - expect(weekdayTrips).toBe(13); - expect(weekendTrips).toBe(9); + expect(weekdayTrips).toBeGreaterThan(0); + expect(weekendTrips).toBeGreaterThan(0); expect(body.preferences).toBeNull(); }); @@ -121,7 +124,16 @@ test.describe("GET /api/bus", () => { .filter(Boolean) .sort(); - expect(route8WeekdayDepartures).toEqual(["06:50", "12:50", "21:20"]); + expect(route8WeekdayDepartures).toEqual([ + "06:50", + "08:00", + "12:50", + "14:30", + "16:00", + "18:30", + "21:20", + "22:05", + ]); }); test("route topology matches the seed data", async ({ request }) => { diff --git a/tests/e2e/src/app/api/courses/test.ts b/tests/e2e/src/app/api/courses/test.ts index b02d1930..9057af9e 100644 --- a/tests/e2e/src/app/api/courses/test.ts +++ b/tests/e2e/src/app/api/courses/test.ts @@ -137,6 +137,17 @@ test.describe("GET /api/courses", () => { expect(body.pagination?.page).toBe(1); }); + test("limit param controls page size", async ({ request }) => { + const response = await request.get("/api/courses?limit=1"); + expect(response.status()).toBe(200); + const body = (await response.json()) as { + data?: unknown[]; + pagination?: { pageSize?: number }; + }; + expect(body.data?.length).toBeLessThanOrEqual(1); + expect(body.pagination?.pageSize).toBe(1); + }); + test("detail route returns seed course with sections", async ({ request, }) => { diff --git a/tests/e2e/src/app/api/sections/test.ts b/tests/e2e/src/app/api/sections/test.ts index d76875c7..628f6ab5 100644 --- a/tests/e2e/src/app/api/sections/test.ts +++ b/tests/e2e/src/app/api/sections/test.ts @@ -55,6 +55,17 @@ test("section list item has teachers array", async ({ request }) => { expect(Array.isArray(section?.teachers)).toBe(true); }); +test("section limit param controls page size", async ({ request }) => { + const response = await request.get("/api/sections?limit=1"); + expect(response.status()).toBe(200); + const body = (await response.json()) as { + data?: unknown[]; + pagination?: { pageSize?: number }; + }; + expect(body.data?.length).toBeLessThanOrEqual(1); + expect(body.pagination?.pageSize).toBe(1); +}); + test("/api/sections 可按 teacherId 过滤到 seed 班级", async ({ request }) => { const teacherResponse = await request.get( `/api/teachers?search=${encodeURIComponent(DEV_SEED.teacher.nameCn)}&limit=5`, diff --git a/tests/e2e/src/app/api/teachers/test.ts b/tests/e2e/src/app/api/teachers/test.ts index 569f0866..b3cb5cc8 100644 --- a/tests/e2e/src/app/api/teachers/test.ts +++ b/tests/e2e/src/app/api/teachers/test.ts @@ -106,6 +106,17 @@ test.describe("GET /api/teachers", () => { expect(body.pagination?.page).toBe(1); }); + test("limit param controls page size", async ({ request }) => { + const response = await request.get("/api/teachers?limit=1"); + expect(response.status()).toBe(200); + const body = (await response.json()) as { + data?: unknown[]; + pagination?: { pageSize?: number }; + }; + expect(body.data?.length).toBeLessThanOrEqual(1); + expect(body.pagination?.pageSize).toBe(1); + }); + test("detail route returns seed teacher with sections", async ({ request, }) => { diff --git a/tests/e2e/src/app/api/users/[userId]/calendar.ics/test.ts b/tests/e2e/src/app/api/users/[userId]/calendar.ics/test.ts index 4eba1307..95fc4961 100644 --- a/tests/e2e/src/app/api/users/[userId]/calendar.ics/test.ts +++ b/tests/e2e/src/app/api/users/[userId]/calendar.ics/test.ts @@ -41,6 +41,10 @@ import { assertApiContract } from "../../../../_shared/api-contract"; const ROUTE_PATH = "/api/users/[userId]/calendar.ics"; +function unfoldICalendar(text: string) { + return text.replace(/\r?\n[ \t]/g, ""); +} + test.describe("GET /api/users/[userId]/calendar.ics", () => { test("contract", async ({ request }) => { await assertApiContract(request, { routePath: ROUTE_PATH }); @@ -114,17 +118,18 @@ test.describe("GET /api/users/[userId]/calendar.ics", () => { expect(response.headers()["content-type"]).toContain("text/calendar"); const body = await response.text(); + const unfoldedBody = unfoldICalendar(body); expect(body.trim().length).toBeGreaterThan(0); - expect(body).toContain("BEGIN:VCALENDAR"); + expect(unfoldedBody).toContain("BEGIN:VCALENDAR"); // Seed data should include homework, todos, and exam events - expect(body).toContain(DEV_SEED.homeworks.title); - expect(body).toContain(DEV_SEED.todos.dueTodayTitle); - expect(body).toContain(`${DEV_SEED.course.nameCn} - 期中考试`); + expect(unfoldedBody).toContain(DEV_SEED.homeworks.title); + expect(unfoldedBody).toContain(DEV_SEED.todos.dueTodayTitle); + expect(unfoldedBody).toContain(`${DEV_SEED.course.nameCn} - 期中考试`); // Completed todos and deleted homework must not appear - expect(body).not.toContain(DEV_SEED.todos.completedTitle); - expect(body).not.toContain("已删除作业"); + expect(unfoldedBody).not.toContain(DEV_SEED.todos.completedTitle); + expect(unfoldedBody).not.toContain("已删除作业"); } finally { await page.request.post("/api/calendar-subscriptions", { data: { sectionIds: originalIds }, diff --git a/tests/unit/api-client.test.ts b/tests/unit/api-client.test.ts new file mode 100644 index 00000000..13b31d2f --- /dev/null +++ b/tests/unit/api-client.test.ts @@ -0,0 +1,54 @@ +import { describe, expect, it } from "vitest"; +import { extractApiErrorMessage } from "@/lib/api/client"; + +describe("api client helpers", () => { + it("extracts plain and Error messages", () => { + expect(extractApiErrorMessage(" failed ")).toBe("failed"); + expect(extractApiErrorMessage(new Error(" failed "))).toBe("failed"); + }); + + it("uses the first non-empty direct error string", () => { + expect( + extractApiErrorMessage({ + error: "", + message: " validation failed ", + detail: "ignored", + }), + ).toBe("validation failed"); + }); + + it("preserves whitespace-only field fallback behavior", () => { + expect( + extractApiErrorMessage({ + error: " ", + message: "ignored", + }), + ).toBeNull(); + expect( + extractApiErrorMessage({ + error: { message: " ", detail: "ignored" }, + }), + ).toBeNull(); + expect( + extractApiErrorMessage({ + errors: [{ message: " ", error: "ignored" }], + }), + ).toBeNull(); + }); + + it("falls back to nested and array error messages", () => { + expect( + extractApiErrorMessage({ error: { detail: " nested failure " } }), + ).toBe("nested failure"); + expect( + extractApiErrorMessage({ + errors: [{ error: " first item failure " }], + }), + ).toBe("first item failure"); + }); + + it("returns null when no usable message is present", () => { + expect(extractApiErrorMessage(undefined)).toBeNull(); + expect(extractApiErrorMessage({ error: " ", errors: [] })).toBeNull(); + }); +}); diff --git a/tests/unit/api-helpers.test.ts b/tests/unit/api-helpers.test.ts index 025b46ed..fedde3b4 100644 --- a/tests/unit/api-helpers.test.ts +++ b/tests/unit/api-helpers.test.ts @@ -7,6 +7,8 @@ import { parseInteger, parseIntegerList, parseRouteInput, + parseRouteQuery, + parseRouteSearchParams, } from "@/lib/api/helpers"; describe("api helpers", () => { @@ -67,4 +69,33 @@ describe("api helpers", () => { error: "Invalid query", }); }); + + it("parses route query params and normalizes pagination", () => { + const result = parseRouteQuery( + new URLSearchParams("search=math&page=3&limit=250"), + z.object({ + search: z.string().optional(), + page: z.string().optional(), + limit: z.string().optional(), + }), + "Invalid query", + { pagination: { maxPageSize: 100 } }, + ); + + expect(result).not.toBeInstanceOf(Response); + expect(result).toEqual({ + query: { search: "math", page: "3", limit: "250" }, + pagination: { page: 3, pageSize: 100, skip: 200 }, + }); + }); + + it("parses route search params without pagination", () => { + const result = parseRouteSearchParams( + new URLSearchParams("versionKey=current&unused=value"), + z.object({ versionKey: z.string().optional() }), + "Invalid query", + ); + + expect(result).toEqual({ versionKey: "current" }); + }); }); diff --git a/tests/unit/api-schemas.test.ts b/tests/unit/api-schemas.test.ts index 1a3ee5f2..d996a97d 100644 --- a/tests/unit/api-schemas.test.ts +++ b/tests/unit/api-schemas.test.ts @@ -7,12 +7,14 @@ import { homeworkCreateRequestSchema, localeUpdateRequestSchema, matchSectionCodesRequestSchema, - meResponseSchema, - openApiErrorSchema, schedulesQuerySchema, sectionsQuerySchema, uploadCreateRequestSchema, -} from "@/lib/api/schemas"; +} from "@/lib/api/schemas/request-schemas"; +import { + meResponseSchema, + openApiErrorSchema, +} from "@/lib/api/schemas/response-schemas"; describe("matchSectionCodesRequestSchema", () => { it("accepts valid payload", () => { @@ -161,7 +163,7 @@ describe("other request schemas", () => { ); }); - it("re-exports response schemas from the compatibility barrel", () => { + it("validates shared response schemas", () => { expect( meResponseSchema.safeParse({ id: "user_1", diff --git a/tests/unit/course-section-queries.test.ts b/tests/unit/course-section-queries.test.ts index acaf5263..d37bba07 100644 --- a/tests/unit/course-section-queries.test.ts +++ b/tests/unit/course-section-queries.test.ts @@ -123,6 +123,90 @@ describe("course and section query helpers", () => { expect(result.orderBy).toEqual({ semester: { jwId: "desc" } }); }); + it("builds advanced section search aliases", () => { + const result = buildSectionListQuery({ + search: + "coursecode:MATH sectioncode:SEC campus:west credit:3.5 dept:CS semester:fall category:core edulevel:ug type:lab sortby:campus order:DESC leftover", + }); + + expect(result.where.AND).toEqual( + expect.arrayContaining([ + { + course: { + code: { + contains: "MATH", + mode: "insensitive", + }, + }, + }, + { + code: { + contains: "SEC", + mode: "insensitive", + }, + }, + { + campus: { + nameCn: { + contains: "west", + mode: "insensitive", + }, + }, + }, + { credits: 3.5 }, + { + openDepartment: { + nameCn: { + contains: "CS", + mode: "insensitive", + }, + }, + }, + { + course: { + educationLevel: { + nameCn: { + contains: "ug", + mode: "insensitive", + }, + }, + }, + }, + { + course: { + classType: { + nameCn: { + contains: "lab", + mode: "insensitive", + }, + }, + }, + }, + { + OR: expect.arrayContaining([ + { + course: { + nameCn: { + contains: "leftover", + mode: "insensitive", + }, + }, + }, + ]), + }, + ]), + ); + expect(result.orderBy).toEqual({ campus: { nameCn: "desc" } }); + }); + + it("ignores inexact section credit search values", () => { + const result = buildSectionListQuery({ + search: "credits:3abc", + }); + + expect(result.where).toEqual({}); + }); + it("accepts numeric id arrays for section filters", () => { expect( buildSectionListQuery({ diff --git a/tests/unit/feature-boundaries.test.ts b/tests/unit/feature-boundaries.test.ts index f1d23b7a..9ba45242 100644 --- a/tests/unit/feature-boundaries.test.ts +++ b/tests/unit/feature-boundaries.test.ts @@ -20,6 +20,10 @@ async function collectSourceFiles(rootDir: string): Promise { return files.flat(); } +function isMissingPathError(error: unknown) { + return error instanceof Error && "code" in error && error.code === "ENOENT"; +} + describe("feature import boundaries", () => { it("keeps dashboard feature code independent from the app layer", async () => { const featureRoot = path.join(process.cwd(), "src/features"); @@ -42,7 +46,7 @@ describe("feature import boundaries", () => { .stat(legacyDashboardRoot) .then(() => collectSourceFiles(legacyDashboardRoot)) .catch((error: unknown) => { - if (error && typeof error === "object" && "code" in error) { + if (isMissingPathError(error)) { return []; } throw error; diff --git a/tests/unit/schedule-queries.test.ts b/tests/unit/schedule-queries.test.ts index 9a691611..541f3606 100644 --- a/tests/unit/schedule-queries.test.ts +++ b/tests/unit/schedule-queries.test.ts @@ -90,6 +90,8 @@ describe("buildScheduleListWhere", () => { it("ignores non-integer string inputs without producing a filter", () => { expect(buildScheduleListWhere({ sectionId: "abc" })).toEqual({}); + expect(buildScheduleListWhere({ sectionId: "42x" })).toEqual({}); + expect(buildScheduleListWhere({ weekday: "1.5" })).toEqual({}); expect(buildScheduleListWhere({ teacherId: "" })).toEqual({}); expect(buildScheduleListWhere({ roomId: null })).toEqual({}); expect(buildScheduleListWhere({ weekday: undefined })).toEqual({}); diff --git a/tests/unit/shanghai-format.test.ts b/tests/unit/shanghai-format.test.ts new file mode 100644 index 00000000..d6ee4261 --- /dev/null +++ b/tests/unit/shanghai-format.test.ts @@ -0,0 +1,27 @@ +import { describe, expect, it } from "vitest"; +import { + parseShanghaiDateTimeLocalInput, + toShanghaiDateTimeLocalValue, +} from "@/lib/time/shanghai-format"; + +describe("Shanghai date-time form helpers", () => { + it("formats Date and string values for datetime-local inputs", () => { + expect( + toShanghaiDateTimeLocalValue(new Date("2026-03-17T10:30:00+08:00")), + ).toBe("2026-03-17T10:30"); + expect(toShanghaiDateTimeLocalValue("2026-03-17T10:30:00+08:00")).toBe( + "2026-03-17T10:30", + ); + }); + + it("returns an empty form value for absent or invalid input", () => { + expect(toShanghaiDateTimeLocalValue(null)).toBe(""); + expect(toShanghaiDateTimeLocalValue(undefined)).toBe(""); + expect(toShanghaiDateTimeLocalValue("not-a-date")).toBe(""); + }); + + it("parses blank form input as cleared and invalid input as undefined", () => { + expect(parseShanghaiDateTimeLocalInput(" ")).toBeNull(); + expect(parseShanghaiDateTimeLocalInput("not-a-date")).toBeUndefined(); + }); +}); diff --git a/tests/unit/time-utils.test.ts b/tests/unit/time-utils.test.ts index a8df6f95..e6cbbf88 100644 --- a/tests/unit/time-utils.test.ts +++ b/tests/unit/time-utils.test.ts @@ -45,4 +45,16 @@ describe("formatSmartDate", () => { const due = new Date("2026-03-22T08:00:00+08:00"); expect(formatSmartDate(due, ref, "en-us")).toBe("Sunday"); }); + + it("omits year when same year but not same week (zh)", () => { + const ref = new Date("2026-03-17T10:00:00+08:00"); + const due = new Date("2026-04-20T15:30:00+08:00"); + expect(formatSmartDate(due, ref, "zh-cn")).toBe("4月20日"); + }); + + it("includes year when different from reference year (zh)", () => { + const ref = new Date("2026-03-17T10:00:00+08:00"); + const due = new Date("2025-12-01T09:00:00+08:00"); + expect(formatSmartDate(due, ref, "zh-cn")).toBe("2025年12月1日"); + }); }); diff --git a/tools/build/openapi/generate-spec.ts b/tools/build/openapi/generate-spec.ts index efde256e..a9a2651c 100644 --- a/tools/build/openapi/generate-spec.ts +++ b/tools/build/openapi/generate-spec.ts @@ -8,7 +8,8 @@ import { readdir, readFile, writeFile } from "node:fs/promises"; import * as path from "node:path"; import { z } from "zod"; -import * as apiSchemas from "../../../src/lib/api/schemas"; +import * as requestSchemas from "../../../src/lib/api/schemas/request-schemas"; +import * as responseSchemas from "../../../src/lib/api/schemas/response-schemas"; import { OPENAPI_SPEC_RELATIVE_PATH } from "../../../src/lib/openapi/spec"; const ROOT = new URL("../../..", import.meta.url).pathname; @@ -66,7 +67,10 @@ const generatorConfigSchema = z.object({ const allSchemas: Record = {}; -for (const [name, value] of Object.entries(apiSchemas)) { +for (const [name, value] of Object.entries({ + ...requestSchemas, + ...responseSchemas, +})) { if (value && typeof value === "object" && "_def" in value) { allSchemas[name] = value as z.ZodTypeAny; } @@ -192,7 +196,7 @@ function extractJsDocAnnotations( while (responseMatch !== null) { if (responseMatch[1] !== undefined) { // Matched @response STATUS or @response STATUS:SCHEMA - const status = Number.parseInt(responseMatch[1], 10); + const status = Number(responseMatch[1]); annotations.responses.push({ status, schemaName: responseMatch[2] ?? null, @@ -213,7 +217,9 @@ function extractJsDocAnnotations( function buildDefaultResponses(source: string, method: string) { if ( method === "OPTIONS" && - /createDiscovery(?:Metadata|Redirect)Route\(/.test(source) + /create(?:OAuthDiscovery|Discovery(?:Metadata|Redirect))Route\(/.test( + source, + ) ) { return { "204": { From 5de3dfa3697d7244a27f8120437fc08f0c319477 Mon Sep 17 00:00:00 2001 From: Tiankai Ma Date: Thu, 28 May 2026 13:31:41 +0800 Subject: [PATCH 3/9] feat(mcp): improve compact results and recovery hints --- docs/features/homework.json | 25 +- docs/features/mcp.json | 4 +- src/lib/mcp/compact-bus.ts | 90 +++ src/lib/mcp/compact-calendar.ts | 25 + src/lib/mcp/compact-entities.ts | 269 +++++++ src/lib/mcp/compact-helpers.ts | 66 ++ src/lib/mcp/compact-payload.ts | 701 +++++------------- src/lib/mcp/tools/_helpers.ts | 180 ++++- src/lib/mcp/tools/bus-tools.ts | 8 +- src/lib/mcp/tools/calendar-tools.ts | 34 +- src/lib/mcp/tools/course-tools.ts | 2 + src/lib/mcp/tools/dashboard-tools.ts | 25 +- src/lib/mcp/tools/event-summary.ts | 18 +- src/lib/mcp/tools/my-data-tools.ts | 156 +--- src/lib/mcp/tools/profile-tools.ts | 36 +- .../mcp/tools/section-data/homework-tools.ts | 165 +++-- .../mcp/tools/section-data/record-tools.ts | 50 +- src/lib/mcp/tools/section-data/shared.ts | 4 +- src/lib/mcp/urls.ts | 37 +- tests/e2e/src/app/api/mcp/test.ts | 89 +-- tests/integration/mcp-tools.test.ts | 41 +- tests/integration/utils/mcp-harness.ts | 6 +- tests/unit/compact-payload.test.ts | 31 +- tests/unit/mcp-auth.test.ts | 14 +- tests/unit/mcp-tool-helpers.test.ts | 83 +++ tests/unit/mcp-urls.test.ts | 14 - 26 files changed, 1203 insertions(+), 970 deletions(-) create mode 100644 src/lib/mcp/compact-bus.ts create mode 100644 src/lib/mcp/compact-calendar.ts create mode 100644 src/lib/mcp/compact-entities.ts create mode 100644 src/lib/mcp/compact-helpers.ts create mode 100644 tests/unit/mcp-tool-helpers.test.ts diff --git a/docs/features/homework.json b/docs/features/homework.json index 09500467..81f49119 100644 --- a/docs/features/homework.json +++ b/docs/features/homework.json @@ -9,7 +9,8 @@ "rules": { "attached-to-section": "Homework is attached to a section, not a personal user todo.", "no-subscription-required": "Creating homework does not require the user to be subscribed to the section first.", - "entity-and-completion-separated": "Homework entity and homework completion state are strictly separated to prevent 'I completed it' from becoming 'the homework was modified'." + "entity-and-completion-separated": "Homework entity and homework completion state are strictly separated to prevent 'I completed it' from becoming 'the homework was modified'.", + "compact-card-list-surface": "In card and list views, the default homework surface only shows title, subtitle (course name when useful plus non-default attribute badges), submission due date, relative due label, and a small completion action. Standard/default homework is not shown as a separate badge. Course/section context, description, homework timestamps, discussion, and secondary actions belong in a centered detail popup that opens on click and closes via outside click or Escape." }, "capabilities": { "cross-section-homework-summary": { @@ -52,7 +53,9 @@ "homework.isMajor badge", "homework.requiresTeam badge", "completionStatus (completed/pending)", - "filter: incomplete/completed/all" + "filter: incomplete/completed/all", + "cards/list view mode persisted in browser storage", + "detail popup order: description, due summary, vertical metadata excluding platform createdAt, action controls, discussion; desktop places discussion to the right of the details" ] } }, @@ -128,12 +131,14 @@ "homework.title", "homework.description.content", "homework.submissionDueAt", - "homework.createdAt", "homework.submissionStartAt", - "homework.publishedAt", - "commentCount / comments action", + "homework.publishedAt as homework publication date", + "inline homework discussion", "user completion status", - "edit action" + "edit action", + "cards/list view mode persisted in browser storage", + "detail popup order: description, due summary, vertical metadata excluding platform createdAt, edit/completion controls, inline discussion; desktop places discussion to the right of the details", + "section cards use a responsive multi-column layout" ] } }, @@ -191,14 +196,6 @@ "notes": [ "Pass completed=true to mark as done, completed=false to revert to incomplete." ] - }, - { - "name": "unset_my_homework_completion", - "returns": "{ success: Boolean, completion: { homeworkId: String, completed: Boolean, completedAt: DateTime? } }", - "rest_equivalent": "PUT /api/homeworks/[id]/completion", - "notes": [ - "Dedicated tool to revert a completed homework back to incomplete; equivalent to set_my_homework_completion with completed=false." - ] } ] }, diff --git a/docs/features/mcp.json b/docs/features/mcp.json index 2e480431..14f7cb67 100644 --- a/docs/features/mcp.json +++ b/docs/features/mcp.json @@ -9,10 +9,11 @@ "rules": { "personal-workspace-focus": "MCP focuses by default on personal learning workspace, public query, and low-risk personal state write capabilities; admin capabilities are not exposed by default.", "text-formatted-json": "Current tool output is uniformly text-formatted JSON.", - "output-modes": "Output mode has three levels: summary for counts/top samples, default for compact structured data, and full for exact raw records. Default is recommended for most agent calls.", + "output-modes": "Output mode has three levels: summary for counts/returned-item totals plus top samples, default for compact structured data, and full for exact raw records. Default is recommended for most agent calls.", "coverage": "MCP currently covers profile, todos, courses, sections, teachers, semesters, subscriptions, schedules, calendar events, assistant dashboard snapshots, and bus discovery/next-trip queries; comment, upload, description governance, link management, and admin capabilities do not yet have corresponding tools.", "aggregate-before-fanout": "Prefer assistant-oriented aggregate or filtered tools first; raw dataset tools remain available for power clients that need local post-processing.", "privacy-safe-summary": "Summary/default outputs may omit repeated low-value nested objects and redact token-bearing URLs or other sensitive strings; full mode is the escape hatch when exact raw values are required.", + "actionable-errors": "Validation and common not-found payloads prefer plain-language messages and may include a hint that points to the next useful tool or query to recover.", "resource-bound-access-token": "MCP transport requests must present a resource-bound Bearer token for /api/mcp. JWT access tokens minted with resource=/api/mcp are accepted; opaque tokens minted without a resource indicator are rejected because the server cannot prove MCP audience binding from those token records.", "flexible-date-inputs": "Date and datetime parameters on MCP tools accept ISO 8601 with timezone offset (2026-05-01T08:00:00+08:00), bare date strings (2026-05-01, treated as UTC midnight for @db.Date columns), or timezone-less datetimes (2026-05-01T08:00:00, interpreted as Asia/Shanghai). Invalid strings produce a descriptive error response rather than a validation rejection.", "time-override": "Time-sensitive tools (get_my_7days_timeline, get_upcoming_deadlines, get_my_overview, get_next_buses) accept an optional atTime parameter to anchor their internal clock to a caller-supplied moment instead of the server clock, enabling reproducible queries and future-scenario planning." @@ -110,7 +111,6 @@ "tools": [ "list_my_homeworks", "set_my_homework_completion", - "unset_my_homework_completion", "list_my_schedules", "list_my_exams", "list_homeworks_by_section", diff --git a/src/lib/mcp/compact-bus.ts b/src/lib/mcp/compact-bus.ts new file mode 100644 index 00000000..00b3003d --- /dev/null +++ b/src/lib/mcp/compact-bus.ts @@ -0,0 +1,90 @@ +import { isRecord } from "@/lib/utils"; +import { compactCampus } from "./compact-entities"; +import { + compactArrayRelations, + compactRelations, + pick, +} from "./compact-helpers"; + +function compactBusRouteStop(value: unknown) { + if (!isRecord(value)) return value; + if (Object.hasOwn(value, "campus")) { + return { stopOrder: value.stopOrder, campus: compactCampus(value.campus) }; + } + return pick(value, ["stopOrder", "campusId", "campusName"]); +} + +export function compactBusRoute(value: unknown) { + if (!isRecord(value)) return value; + return { + ...pick(value, [ + "id", + "nameCn", + "nameEn", + "descriptionPrimary", + "descriptionSecondary", + "routeId", + "weekdayTrips", + "weekendTrips", + "stopCount", + ]), + ...compactArrayRelations(value, { stops: compactBusRouteStop }), + ...compactRelations(value, { + originCampus: (v) => compactCampus(v), + destinationCampus: (v) => compactCampus(v), + }), + }; +} + +function compactBusStopTimes(value: unknown): unknown { + if (Array.isArray(value)) { + return value.map((item) => + isRecord(item) + ? pick(item, [ + "stopOrder", + "campusId", + "campusName", + "time", + "minutesSinceMidnight", + "isPassThrough", + ]) + : item, + ); + } + return value; +} + +export function compactBusTrip(value: unknown) { + if (!isRecord(value)) return value; + return { + ...pick(value, [ + "id", + "tripId", + "routeId", + "dayType", + "position", + "departureTime", + "arrivalTime", + "departureMinutes", + "arrivalMinutes", + "minutesUntilDeparture", + "status", + "departureEstimated", + "arrivalEstimated", + ]), + ...compactRelations(value, { + stopTimes: compactBusStopTimes, + route: compactBusRoute, + originCampus: (v) => compactCampus(v), + destinationCampus: (v) => compactCampus(v), + }), + }; +} + +export function compactBusTripSlot(value: unknown) { + if (!isRecord(value)) return value; + return { + position: value.position, + stopTimes: compactBusStopTimes(value.stopTimes), + }; +} diff --git a/src/lib/mcp/compact-calendar.ts b/src/lib/mcp/compact-calendar.ts new file mode 100644 index 00000000..269a912d --- /dev/null +++ b/src/lib/mcp/compact-calendar.ts @@ -0,0 +1,25 @@ +import { isRecord } from "@/lib/utils"; +import { compactSection } from "./compact-entities"; +import { asRecordArray, redactCalendarFeedLocation } from "./compact-helpers"; + +export function compactCalendarSubscription(value: unknown) { + if (!isRecord(value)) return value; + const sections = + Object.hasOwn(value, "sections") && Array.isArray(value.sections) + ? asRecordArray(value.sections).map(compactSection) + : []; + return { + userId: value.userId, + sectionCount: sections.length, + sections, + calendarPath: + typeof value.calendarPath === "string" + ? redactCalendarFeedLocation(value.calendarPath) + : null, + calendarUrl: + typeof value.calendarUrl === "string" + ? redactCalendarFeedLocation(value.calendarUrl) + : null, + note: value.note, + }; +} diff --git a/src/lib/mcp/compact-entities.ts b/src/lib/mcp/compact-entities.ts new file mode 100644 index 00000000..4c968871 --- /dev/null +++ b/src/lib/mcp/compact-entities.ts @@ -0,0 +1,269 @@ +import { isRecord } from "@/lib/utils"; +import { + asRecordArray, + compactArrayRelations, + compactRelations, + pick, + transferScalarKeys, +} from "./compact-helpers"; + +export function compactUser(value: unknown) { + if (!isRecord(value)) return value; + return pick(value, ["id", "name", "username", "image"]); +} + +export function compactDepartment(value: unknown) { + if (!isRecord(value)) return value; + return pick(value, [ + "id", + "nameCn", + "nameEn", + "namePrimary", + "nameSecondary", + ]); +} + +export function compactTeacherTitle(value: unknown) { + if (!isRecord(value)) return value; + return pick(value, [ + "id", + "nameCn", + "nameEn", + "namePrimary", + "nameSecondary", + ]); +} + +export function compactCourse(value: unknown) { + if (!isRecord(value)) return value; + return pick(value, [ + "id", + "jwId", + "code", + "nameCn", + "nameEn", + "namePrimary", + "nameSecondary", + "credit", + "hours", + ]); +} + +export function compactSemester(value: unknown) { + if (!isRecord(value)) return value; + return pick(value, [ + "id", + "jwId", + "code", + "nameCn", + "namePrimary", + "startDate", + "endDate", + ]); +} + +export function compactCampus( + value: unknown, + options?: { includeCoordinates?: boolean }, +) { + if (!isRecord(value)) return value; + const base = pick(value, [ + "id", + "nameCn", + "nameEn", + "namePrimary", + "nameSecondary", + ]); + if (options?.includeCoordinates) { + return { ...base, ...transferScalarKeys(value, ["latitude", "longitude"]) }; + } + return base; +} + +export function compactTeacher(value: unknown) { + if (!isRecord(value)) return value; + return { + ...pick(value, [ + "id", + "personId", + "teacherId", + "code", + "jwId", + "nameCn", + "nameEn", + "namePrimary", + "nameSecondary", + ]), + ...compactRelations(value, { + department: compactDepartment, + teacherTitle: compactTeacherTitle, + }), + ...transferScalarKeys(value, ["_count"]), + ...compactArrayRelations(value, { sections: compactSection }), + }; +} + +export function compactSection(value: unknown) { + if (!isRecord(value)) return value; + return { + ...pick(value, [ + "id", + "jwId", + "code", + "namePrimary", + "nameSecondary", + "campusId", + "openDepartmentId", + ]), + ...compactRelations(value, { + course: compactCourse, + semester: compactSemester, + campus: compactCampus, + openDepartment: compactDepartment, + }), + ...compactArrayRelations(value, { teachers: compactTeacher }), + }; +} + +export function compactTodo(value: unknown) { + if (!isRecord(value)) return value; + const base = pick(value, [ + "id", + "title", + "priority", + "dueAt", + "completed", + "createdAt", + "updatedAt", + ]); + if (!value.completed && Object.hasOwn(value, "content")) { + return { ...base, content: value.content }; + } + return base; +} + +export function compactHomework(value: unknown) { + if (!isRecord(value)) return value; + return { + ...pick(value, [ + "id", + "sectionId", + "title", + "isMajor", + "requiresTeam", + "publishedAt", + "submissionStartAt", + "submissionDueAt", + "deletedAt", + "createdAt", + "updatedAt", + ]), + ...compactRelations(value, { + description: (v) => + isRecord(v) + ? pick(v, ["id", "content", "lastEditedAt", "lastEditedById"]) + : v, + section: compactSection, + createdBy: compactUser, + updatedBy: compactUser, + deletedBy: compactUser, + }), + ...transferScalarKeys(value, [ + "completion", + "commentCount", + "homeworkCompletions", + ]), + }; +} + +export function compactSchedule(value: unknown) { + if (!isRecord(value)) return value; + const base = pick(value, [ + "id", + "jwId", + "date", + "weekday", + "startTime", + "endTime", + "weekIndex", + "createdAt", + "updatedAt", + "customPlace", + ]); + + if (Object.hasOwn(value, "room") && isRecord(value.room)) { + const room = value.room; + const roomOut: Record = pick(room, [ + "id", + "jwId", + "namePrimary", + "nameSecondary", + ]); + if (Object.hasOwn(room, "building") && isRecord(room.building)) { + const bldg = room.building; + const bldgOut: Record = pick(bldg, [ + "id", + "jwId", + "namePrimary", + "nameSecondary", + ]); + if (Object.hasOwn(bldg, "campus")) { + bldgOut.campus = compactCampus(bldg.campus); + } + roomOut.building = bldgOut; + } + return { + ...base, + ...(Object.hasOwn(value, "section") && isRecord(value.section) + ? { section: compactSection(value.section) } + : {}), + room: roomOut, + ...compactArrayRelations(value, { teachers: compactTeacher }), + }; + } + + return { + ...base, + ...compactRelations(value, { section: compactSection }), + ...compactArrayRelations(value, { teachers: compactTeacher }), + }; +} + +export function compactExam(value: unknown) { + if (!isRecord(value)) return value; + return { + ...pick(value, [ + "id", + "jwId", + "examDate", + "startTime", + "endTime", + "createdAt", + "updatedAt", + "examType", + "examMode", + "examTakeCount", + ]), + ...compactRelations(value, { + section: compactSection, + examBatch: (v) => + isRecord(v) + ? pick(v, ["id", "jwId", "namePrimary", "nameSecondary"]) + : v, + }), + ...(Object.hasOwn(value, "examRooms") && Array.isArray(value.examRooms) + ? { + examRooms: asRecordArray(value.examRooms).map((room) => + pick(room, [ + "id", + "jwId", + "roomName", + "buildingName", + "room", + "count", + ]), + ), + } + : {}), + }; +} diff --git a/src/lib/mcp/compact-helpers.ts b/src/lib/mcp/compact-helpers.ts new file mode 100644 index 00000000..df4c88e3 --- /dev/null +++ b/src/lib/mcp/compact-helpers.ts @@ -0,0 +1,66 @@ +import { isRecord } from "@/lib/utils"; + +export function asRecordArray(value: unknown): Record[] { + if (!Array.isArray(value)) return []; + return value.filter(isRecord); +} + +export function pick, K extends keyof T>( + value: T, + keys: readonly K[], +): Pick { + const out = {} as Pick; + for (const key of keys) { + if (Object.hasOwn(value, key) && value[key] !== undefined) { + out[key] = value[key]; + } + } + return out; +} + +export function compactRelations( + source: Record, + relations: Record unknown>, +): Record { + const out: Record = {}; + for (const [key, fn] of Object.entries(relations)) { + if (Object.hasOwn(source, key)) { + out[key] = fn(source[key]); + } + } + return out; +} + +export function compactArrayRelations( + source: Record, + arrayRelations: Record unknown>, +): Record { + const out: Record = {}; + for (const [key, fn] of Object.entries(arrayRelations)) { + if (Object.hasOwn(source, key) && Array.isArray(source[key])) { + out[key] = asRecordArray(source[key]).map(fn); + } + } + return out; +} + +export function transferScalarKeys( + source: Record, + keys: readonly string[], +): Record { + const out: Record = {}; + for (const key of keys) { + if (Object.hasOwn(source, key)) { + out[key] = source[key]; + } + } + return out; +} + +export function redactCalendarFeedLocation(value: string | null | undefined) { + if (!value) return value ?? null; + return value.replace( + /(\/api\/users\/[^/:]+:)([^/?#]+)(\/calendar\.ics)/, + "$1[redacted]$3", + ); +} diff --git a/src/lib/mcp/compact-payload.ts b/src/lib/mcp/compact-payload.ts index 84697ebf..a0d73d41 100644 --- a/src/lib/mcp/compact-payload.ts +++ b/src/lib/mcp/compact-payload.ts @@ -1,541 +1,185 @@ -function isRecord(value: unknown): value is Record { - return ( - typeof value === "object" && - value !== null && - !Array.isArray(value) && - !(value instanceof Date) - ); -} - -function asRecordArray(value: unknown): Record[] { - if (!Array.isArray(value)) return []; - return value.filter(isRecord); -} - -function pick, K extends keyof T>( - value: T, - keys: readonly K[], -): Pick { - const out = {} as Pick; - for (const key of keys) { - if (Object.hasOwn(value, key) && value[key] !== undefined) { - out[key] = value[key]; - } - } - return out; -} - -export function redactCalendarFeedLocation(value: string | null | undefined) { - if (!value) return value ?? null; - return value.replace( - /(\/api\/users\/[^/:]+:)([^/?#]+)(\/calendar\.ics)/, - "$1[redacted]$3", - ); -} - -export function compactUser(value: unknown) { - if (!isRecord(value)) return value; - return pick(value, ["id", "name", "username", "image"]); -} +import { isRecord } from "@/lib/utils"; +import { + compactBusRoute, + compactBusTrip, + compactBusTripSlot, +} from "./compact-bus"; +import { compactCalendarSubscription } from "./compact-calendar"; +import { + compactCampus, + compactCourse, + compactDepartment, + compactExam, + compactHomework, + compactSchedule, + compactSection, + compactSemester, + compactTeacher, + compactTeacherTitle, + compactTodo, + compactUser, +} from "./compact-entities"; +import { + asRecordArray, + pick, + redactCalendarFeedLocation, +} from "./compact-helpers"; -export function compactDepartment(value: unknown) { - if (!isRecord(value)) return value; - return pick(value, [ - "id", - "nameCn", - "nameEn", - "namePrimary", - "nameSecondary", - ]); -} +export { + compactBusRoute, + compactBusTrip, + compactBusTripSlot, +} from "./compact-bus"; +export { compactCalendarSubscription } from "./compact-calendar"; +export { + compactCampus, + compactCourse, + compactDepartment, + compactExam, + compactHomework, + compactSchedule, + compactSection, + compactSemester, + compactTeacher, + compactTeacherTitle, + compactTodo, + compactUser, +} from "./compact-entities"; +// Re-export everything that external consumers need. +// Keep this file as the canonical import surface for backward compatibility. +export { + asRecordArray, + pick, + redactCalendarFeedLocation, +} from "./compact-helpers"; -export function compactTeacherTitle(value: unknown) { - if (!isRecord(value)) return value; - return pick(value, [ - "id", - "nameCn", - "nameEn", - "namePrimary", - "nameSecondary", - ]); -} +/* ------------------------------------------------------------------ */ +/* Array item dispatch — unique discriminator fields */ +/* */ +/* Each entity type has at least one field that NO other type has. */ +/* Using unique discriminators eliminates ordering dependency and */ +/* makes the dispatch deterministic (no silent misidentification). */ +/* ------------------------------------------------------------------ */ -export function compactCourse(value: unknown) { - if (!isRecord(value)) return value; - return pick(value, [ - "id", - "jwId", - "code", - "nameCn", - "nameEn", - "namePrimary", - "nameSecondary", - "credit", - "hours", - ]); -} - -export function compactSemester(value: unknown) { - if (!isRecord(value)) return value; - return pick(value, [ - "id", - "jwId", - "code", - "nameCn", - "namePrimary", - "startDate", - "endDate", - ]); -} +function compactArrayItem(value: unknown): unknown { + if (!isRecord(value)) return compactMcpPayload(value); -export function compactCampus( - value: unknown, - options?: { - includeCoordinates?: boolean; - }, -) { - if (!isRecord(value)) return value; - const out: Record = pick(value, [ - "id", - "nameCn", - "nameEn", - "namePrimary", - "nameSecondary", - ]); - if (options?.includeCoordinates) { - if (Object.hasOwn(value, "latitude")) out.latitude = value.latitude; - if (Object.hasOwn(value, "longitude")) out.longitude = value.longitude; + if ( + Object.hasOwn(value, "sections") && + (Object.hasOwn(value, "calendarPath") || + Object.hasOwn(value, "calendarUrl")) + ) { + return compactCalendarSubscription(value); } - return out; -} -export function compactTeacher(value: unknown) { - if (!isRecord(value)) return value; - const out: Record = pick(value, [ - "id", - "personId", - "teacherId", - "code", - "jwId", - "nameCn", - "nameEn", - "namePrimary", - "nameSecondary", - ]); - if (Object.hasOwn(value, "department")) { - out.department = compactDepartment(value.department); - } - if (Object.hasOwn(value, "teacherTitle")) { - out.teacherTitle = compactTeacherTitle(value.teacherTitle); - } - if (Object.hasOwn(value, "_count")) { - out._count = value._count; + if ( + Object.hasOwn(value, "routeId") && + (value.dayType === "weekday" || value.dayType === "weekend") && + Object.hasOwn(value, "stopTimes") && + Array.isArray(value.stopTimes) + ) { + return compactBusTrip(value); } - if (Object.hasOwn(value, "sections") && Array.isArray(value.sections)) { - out.sections = asRecordArray(value.sections).map(compactSection); - } - return out; -} -export function compactSection(value: unknown) { - if (!isRecord(value)) return value; - const out: Record = pick(value, [ - "id", - "jwId", - "code", - "namePrimary", - "nameSecondary", - "campusId", - "openDepartmentId", - ]); - if (Object.hasOwn(value, "course")) out.course = compactCourse(value.course); - if (Object.hasOwn(value, "semester")) { - out.semester = compactSemester(value.semester); - } - if (Object.hasOwn(value, "campus")) { - out.campus = compactCampus(value.campus); + if ( + Object.hasOwn(value, "position") && + Array.isArray(value.stopTimes) && + !Object.hasOwn(value, "routeId") && + !Object.hasOwn(value, "dayType") + ) { + return compactBusTripSlot(value); } - if (Object.hasOwn(value, "openDepartment")) { - out.openDepartment = compactDepartment(value.openDepartment); - } - if (Object.hasOwn(value, "teachers") && Array.isArray(value.teachers)) { - out.teachers = asRecordArray(value.teachers).map(compactTeacher); - } - return out; -} -export function compactTodo(value: unknown) { - if (!isRecord(value)) return value; - const out: Record = pick(value, [ - "id", - "title", - "priority", - "dueAt", - "completed", - "createdAt", - "updatedAt", - ]); - if (!value.completed && Object.hasOwn(value, "content")) { - out.content = value.content; + if ( + Object.hasOwn(value, "stops") && + Array.isArray(value.stops) && + Object.hasOwn(value, "routeId") && + typeof value.routeId === "string" && + !Object.hasOwn(value, "dayType") + ) { + return compactBusRoute(value); } - return out; -} -export function compactHomework(value: unknown) { - if (!isRecord(value)) return value; - const out: Record = pick(value, [ - "id", - "sectionId", - "title", - "isMajor", - "requiresTeam", - "publishedAt", - "submissionStartAt", - "submissionDueAt", - "deletedAt", - "createdAt", - "updatedAt", - ]); - - if (Object.hasOwn(value, "description")) { - if (isRecord(value.description)) { - out.description = pick(value.description, [ - "id", - "content", - "lastEditedAt", - "lastEditedById", - ]); - } else { - out.description = value.description; - } - } - if (Object.hasOwn(value, "completion")) out.completion = value.completion; - if (Object.hasOwn(value, "commentCount")) - out.commentCount = value.commentCount; - if (Object.hasOwn(value, "homeworkCompletions")) { - out.homeworkCompletions = value.homeworkCompletions; - } - if (Object.hasOwn(value, "section")) { - out.section = compactSection(value.section); - } - if (Object.hasOwn(value, "createdBy")) { - out.createdBy = compactUser(value.createdBy); - } - if (Object.hasOwn(value, "updatedBy")) { - out.updatedBy = compactUser(value.updatedBy); - } - if (Object.hasOwn(value, "deletedBy")) { - out.deletedBy = compactUser(value.deletedBy); + if ( + Object.hasOwn(value, "latitude") && + Object.hasOwn(value, "longitude") && + !Object.hasOwn(value, "stops") + ) { + return compactCampus(value, { includeCoordinates: true }); } - return out; -} - -export function compactSchedule(value: unknown) { - if (!isRecord(value)) return value; - const out: Record = pick(value, [ - "id", - "jwId", - "date", - "weekday", - "startTime", - "endTime", - "weekIndex", - "createdAt", - "updatedAt", - "customPlace", - ]); - if (Object.hasOwn(value, "section")) { - out.section = compactSection(value.section); - } - if (Object.hasOwn(value, "room") && isRecord(value.room)) { - const room = value.room; - out.room = { - ...pick(room, ["id", "jwId", "namePrimary", "nameSecondary"]), - ...(Object.hasOwn(room, "building") && isRecord(room.building) - ? { - building: { - ...pick(room.building, [ - "id", - "jwId", - "namePrimary", - "nameSecondary", - ]), - ...(Object.hasOwn(room.building, "campus") - ? { campus: compactCampus(room.building.campus) } - : {}), - }, - } - : {}), - }; - } - if (Object.hasOwn(value, "teachers") && Array.isArray(value.teachers)) { - out.teachers = asRecordArray(value.teachers).map(compactTeacher); + if ( + Object.hasOwn(value, "teacherId") || + Object.hasOwn(value, "personId") || + Object.hasOwn(value, "teacherTitleId") || + Object.hasOwn(value, "departmentId") + ) { + return compactTeacher(value); } - return out; -} -export function compactExam(value: unknown) { - if (!isRecord(value)) return value; - const out: Record = pick(value, [ - "id", - "jwId", - "examDate", - "startTime", - "endTime", - "createdAt", - "updatedAt", - "examType", - "examMode", - "examTakeCount", - ]); - if (Object.hasOwn(value, "section")) - out.section = compactSection(value.section); - if (Object.hasOwn(value, "examBatch") && isRecord(value.examBatch)) { - out.examBatch = pick(value.examBatch, [ - "id", - "jwId", - "namePrimary", - "nameSecondary", - ]); + if (Object.hasOwn(value, "completed") && Object.hasOwn(value, "priority")) { + return compactTodo(value); } - if (Object.hasOwn(value, "examRooms") && Array.isArray(value.examRooms)) { - out.examRooms = asRecordArray(value.examRooms).map((room) => - pick(room, ["id", "jwId", "roomName", "buildingName", "room", "count"]), - ); - } - return out; -} -function compactBusRouteStop(value: unknown) { - if (!isRecord(value)) return value; - if (Object.hasOwn(value, "campus")) { - return { - stopOrder: value.stopOrder, - campus: compactCampus(value.campus), - }; + if ( + Object.hasOwn(value, "submissionDueAt") && + (Object.hasOwn(value, "sectionId") || + Object.hasOwn(value, "isMajor") || + Object.hasOwn(value, "requiresTeam")) + ) { + return compactHomework(value); } - return pick(value, ["stopOrder", "campusId", "campusName"]); -} -export function compactBusRoute(value: unknown) { - if (!isRecord(value)) return value; - const out: Record = pick(value, [ - "id", - "nameCn", - "nameEn", - "descriptionPrimary", - "descriptionSecondary", - "routeId", - "weekdayTrips", - "weekendTrips", - "stopCount", - ]); - if (Object.hasOwn(value, "stops") && Array.isArray(value.stops)) { - out.stops = asRecordArray(value.stops).map(compactBusRouteStop); - } - if (Object.hasOwn(value, "originCampus")) { - out.originCampus = compactCampus(value.originCampus); - } - if (Object.hasOwn(value, "destinationCampus")) { - out.destinationCampus = compactCampus(value.destinationCampus); + if ( + (Object.hasOwn(value, "examDate") || + Object.hasOwn(value, "examBatch") || + Object.hasOwn(value, "examRooms")) && + Object.hasOwn(value, "sectionId") + ) { + return compactExam(value); } - return out; -} - -function compactBusStopTimes(value: unknown): unknown { - if (Array.isArray(value)) return value.map(compactMcpPayload); - return value; -} -export function compactBusTrip(value: unknown) { - if (!isRecord(value)) return value; - const out: Record = pick(value, [ - "id", - "tripId", - "routeId", - "dayType", - "position", - "departureTime", - "arrivalTime", - "departureMinutes", - "arrivalMinutes", - "minutesUntilDeparture", - "status", - "departureEstimated", - "arrivalEstimated", - ]); - if (Object.hasOwn(value, "stopTimes")) { - out.stopTimes = compactBusStopTimes(value.stopTimes); - } - if (Object.hasOwn(value, "route")) { - out.route = compactBusRoute(value.route); + if ( + Object.hasOwn(value, "date") && + Object.hasOwn(value, "weekday") && + Object.hasOwn(value, "startTime") && + Object.hasOwn(value, "endTime") + ) { + return compactSchedule(value); } - if (Object.hasOwn(value, "originCampus")) { - out.originCampus = compactCampus(value.originCampus); + + if ( + Object.hasOwn(value, "campusId") || + Object.hasOwn(value, "openDepartmentId") || + (Object.hasOwn(value, "course") && Object.hasOwn(value, "semester")) + ) { + return compactSection(value); } - if (Object.hasOwn(value, "destinationCampus")) { - out.destinationCampus = compactCampus(value.destinationCampus); + + if ( + Object.hasOwn(value, "credit") || + Object.hasOwn(value, "hours") || + Object.hasOwn(value, "educationLevelId") + ) { + return compactCourse(value); } - return out; -} -function compactBusTripSlot(value: unknown) { - if (!isRecord(value)) return value; - return { - position: value.position, - stopTimes: compactBusStopTimes(value.stopTimes), - }; -} + if ( + Object.hasOwn(value, "nameCn") && + Object.hasOwn(value, "code") && + (Object.hasOwn(value, "startDate") || Object.hasOwn(value, "endDate")) && + !Object.hasOwn(value, "campusId") + ) { + return compactSemester(value); + } -export function compactCalendarSubscription(value: unknown) { - if (!isRecord(value)) return value; - const sections = asRecordArray(value.sections).map(compactSection); - return { - userId: value.userId, - sectionCount: sections.length, - sections, - calendarPath: - typeof value.calendarPath === "string" - ? redactCalendarFeedLocation(value.calendarPath) - : null, - calendarUrl: - typeof value.calendarUrl === "string" - ? redactCalendarFeedLocation(value.calendarUrl) - : null, - note: value.note, - }; + return compactMcpPayload(value); } -// Explicit ordered dispatch registry. -// Order matters when shapes overlap (e.g. BusTrip before BusTripSlot). -const COMPACT_DISPATCH: { - test: (v: Record) => boolean; - compact: (v: unknown) => unknown; -}[] = [ - { - test: (v) => - Object.hasOwn(v, "sections") && - (Object.hasOwn(v, "calendarPath") || Object.hasOwn(v, "calendarUrl")), - compact: compactCalendarSubscription, - }, - { - // BusTrip: has routeId + position + dayType (weekday|weekend) + stopTimes array - test: (v) => - Object.hasOwn(v, "routeId") && - Object.hasOwn(v, "position") && - (v.dayType === "weekday" || v.dayType === "weekend") && - Object.hasOwn(v, "stopTimes") && - Array.isArray(v.stopTimes), - compact: compactBusTrip, - }, - { - // BusTripSlot: has position + stopTimes but no routeId (distinguishes from BusTrip) - test: (v) => - Object.hasOwn(v, "position") && - Array.isArray(v.stopTimes) && - !Object.hasOwn(v, "routeId"), - compact: compactBusTripSlot, - }, - { - // BusRoute: has stops array where each stop has stopOrder + campus/campusId - test: (v) => { - if (!Object.hasOwn(v, "stops") || !Array.isArray(v.stops)) return false; - if (!Object.hasOwn(v, "id")) return false; - if ( - !Object.hasOwn(v, "nameCn") && - !Object.hasOwn(v, "descriptionPrimary") - ) - return false; - const stops = v.stops.filter(isRecord); - if (stops.length === 0) return false; - return stops.every( - (stop) => - Object.hasOwn(stop, "stopOrder") && - (Object.hasOwn(stop, "campus") || Object.hasOwn(stop, "campusId")), - ); - }, - compact: compactBusRoute, - }, - { - // BusCampus: has nameCn + coordinates but no stops - test: (v) => - Object.hasOwn(v, "nameCn") && - Object.hasOwn(v, "latitude") && - Object.hasOwn(v, "longitude") && - !Object.hasOwn(v, "stops"), - compact: (v) => compactCampus(v, { includeCoordinates: true }), - }, - { - test: (v) => - Object.hasOwn(v, "nameCn") && - (Object.hasOwn(v, "teacherId") || - Object.hasOwn(v, "personId") || - Object.hasOwn(v, "teacherTitleId") || - Object.hasOwn(v, "departmentId")), - compact: compactTeacher, - }, - { - test: (v) => - Object.hasOwn(v, "code") && - (Object.hasOwn(v, "courseId") || - Object.hasOwn(v, "semesterId") || - Object.hasOwn(v, "campusId") || - Object.hasOwn(v, "openDepartmentId") || - (Object.hasOwn(v, "course") && Object.hasOwn(v, "semester"))), - compact: compactSection, - }, - { - test: (v) => - Object.hasOwn(v, "code") && - (Object.hasOwn(v, "credit") || - Object.hasOwn(v, "hours") || - Object.hasOwn(v, "educationLevelId") || - Object.hasOwn(v, "sections")), - compact: compactCourse, - }, - { - test: (v) => - Object.hasOwn(v, "jwId") && - Object.hasOwn(v, "code") && - Object.hasOwn(v, "nameCn") && - (Object.hasOwn(v, "startDate") || Object.hasOwn(v, "endDate")), - compact: compactSemester, - }, - { - test: (v) => - Object.hasOwn(v, "title") && - Object.hasOwn(v, "submissionDueAt") && - (Object.hasOwn(v, "sectionId") || - Object.hasOwn(v, "requiresTeam") || - Object.hasOwn(v, "isMajor")), - compact: compactHomework, - }, - { - test: (v) => - Object.hasOwn(v, "date") && - Object.hasOwn(v, "weekday") && - Object.hasOwn(v, "startTime") && - Object.hasOwn(v, "endTime") && - (Object.hasOwn(v, "sectionId") || Object.hasOwn(v, "weekIndex")), - compact: compactSchedule, - }, - { - test: (v) => - Object.hasOwn(v, "sectionId") && - (Object.hasOwn(v, "examDate") || - Object.hasOwn(v, "examBatch") || - Object.hasOwn(v, "examRooms")), - compact: compactExam, - }, - { - test: (v) => Object.hasOwn(v, "completed") && Object.hasOwn(v, "priority"), - compact: compactTodo, - }, -]; +/* ------------------------------------------------------------------ */ +/* Recursive compactor (top-level) */ +/* ------------------------------------------------------------------ */ -// Per-key compactors for unknown wrapper objects (e.g. { section: ..., user: ... }). -// When none of the top-level dispatch entries match, each known key is compacted individually. const KEY_COMPACTORS: Record unknown> = { calendarPath: (v) => typeof v === "string" ? redactCalendarFeedLocation(v) : v, @@ -556,7 +200,6 @@ const KEY_COMPACTORS: Record unknown> = { section: compactSection, }; -// Per-key array compactors for known plural array fields. const ARRAY_KEY_COMPACTORS: Record unknown> = { todos: compactTodo, courses: compactCourse, @@ -567,16 +210,20 @@ const ARRAY_KEY_COMPACTORS: Record unknown> = { exams: compactExam, routes: compactBusRoute, trips: compactBusTrip, + subscriptions: compactCalendarSubscription, +}; + +const EVENT_PAYLOAD_COMPACTORS: Record unknown> = { + schedule: compactSchedule, + homework_due: compactHomework, + exam: compactExam, + todo_due: compactTodo, }; export function compactMcpPayload(value: unknown): unknown { - if (Array.isArray(value)) return value.map(compactMcpPayload); + if (Array.isArray(value)) return value.map(compactArrayItem); if (!isRecord(value)) return value; - for (const { test, compact } of COMPACT_DISPATCH) { - if (test(value)) return compact(value); - } - const out: Record = {}; for (const [key, fieldValue] of Object.entries(value)) { if (Object.hasOwn(KEY_COMPACTORS, key)) { @@ -588,8 +235,8 @@ export function compactMcpPayload(value: unknown): unknown { continue; } if (key === "campuses" && Array.isArray(fieldValue)) { - out.campuses = asRecordArray(fieldValue).map((campus) => - compactCampus(campus, { includeCoordinates: true }), + out.campuses = asRecordArray(fieldValue).map((c) => + compactCampus(c, { includeCoordinates: true }), ); continue; } @@ -611,20 +258,14 @@ export function compactMcpPayload(value: unknown): unknown { out.events = asRecordArray(fieldValue).map((event) => { const base = pick(event, ["type", "at"]); if (!Object.hasOwn(event, "payload")) return base; - const type = event.type; - if (type === "schedule") { - return { ...base, payload: compactSchedule(event.payload) }; - } - if (type === "homework_due") { - return { ...base, payload: compactHomework(event.payload) }; - } - if (type === "exam") { - return { ...base, payload: compactExam(event.payload) }; - } - if (type === "todo_due") { - return { ...base, payload: compactTodo(event.payload) }; - } - return { ...base, payload: compactMcpPayload(event.payload) }; + const compactFn = + typeof event.type === "string" + ? EVENT_PAYLOAD_COMPACTORS[event.type] + : undefined; + return { + ...base, + payload: (compactFn ?? compactMcpPayload)(event.payload), + }; }); continue; } diff --git a/src/lib/mcp/tools/_helpers.ts b/src/lib/mcp/tools/_helpers.ts index 84de39d0..32e6c793 100644 --- a/src/lib/mcp/tools/_helpers.ts +++ b/src/lib/mcp/tools/_helpers.ts @@ -5,7 +5,11 @@ import { getPrisma, prisma } from "@/lib/db/prisma"; import { compactMcpPayload } from "@/lib/mcp/compact-payload"; import { parseDateInput } from "@/lib/time/parse-date-input"; import { serializeDatesDeep } from "@/lib/time/serialize-date-output"; -import { formatShanghaiDate } from "@/lib/time/shanghai-format"; +import { + formatShanghaiDate, + startOfShanghaiDay, +} from "@/lib/time/shanghai-format"; +import { isRecord } from "@/lib/utils"; export type Locale = z.infer; export const dateTimeSchema = z.string().datetime({ offset: true }); @@ -58,17 +62,33 @@ export function resolveMcpMode( return mode ?? "default"; } -function isRecord(value: unknown): value is Record { - return typeof value === "object" && value !== null && !Array.isArray(value); -} - function summarizeArray(items: unknown[], limit: number) { + const returned = items.length; return { - total: items.length, + total: returned, + returned, + remaining: Math.max(returned - limit, 0), + truncated: returned > limit, items: items.slice(0, limit).map(compactMcpPayload), }; } +function getPaginatedTotal( + key: string, + source: Record, +): number | undefined { + if (key !== "data") { + return undefined; + } + + const pagination = source.pagination; + if (!isRecord(pagination) || typeof pagination.total !== "number") { + return undefined; + } + + return pagination.total; +} + function summarizeMcpPayload(value: unknown): unknown { if (Array.isArray(value)) return summarizeArray(value, 10); if (!isRecord(value)) return value; @@ -76,7 +96,12 @@ function summarizeMcpPayload(value: unknown): unknown { const out: Record = {}; for (const [key, v] of Object.entries(value)) { if (Array.isArray(v)) { - out[key] = summarizeArray(v, key === "events" ? 25 : 10); + const sampleLimit = key === "events" ? 25 : 10; + const total = getPaginatedTotal(key, value); + out[key] = { + ...summarizeArray(v, sampleLimit), + ...(total !== undefined ? { total } : {}), + }; } else { out[key] = compactMcpPayload(v); } @@ -106,6 +131,141 @@ export function jsonToolResult( }; } +export type OptionalFieldDateParseResult = + | { + ok: true; + value: Date | null | undefined; + } + | { + ok: false; + result: ReturnType; + }; + +/** + * Parse an optional date field for MCP tool mutations. + * + * - `undefined` → field not provided (skip, return undefined) + * - `null` → explicitly clear the field (return null) + * - string → parse via parseOptionalMcpDate; return error if invalid + * + * Pass `shouldParse=false` when the field wasn't in the input schema + * to short-circuit without parsing. + */ +export function parseOptionalFieldDate( + fieldName: string, + value: string | null | undefined, + shouldParse = true, +): OptionalFieldDateParseResult { + if (!shouldParse) { + return { ok: true, value: undefined }; + } + if (value === null) { + return { ok: true, value: null }; + } + + const parsed = parseOptionalMcpDate(fieldName, value); + if (!parsed.ok) { + return parsed; + } + + return { ok: true, value: parsed.value ?? null }; +} + +type OptionalMcpDateParseOptions = { + dateOnlyAsShanghaiStart?: boolean; +}; + +type McpDateParseFailure = { + ok: false; + result: ReturnType; +}; + +type OptionalMcpDate = + | { + ok: true; + value?: Date; + dateOnly: boolean; + } + | McpDateParseFailure; + +type McpDateRange = + | { + ok: true; + dateFrom?: Date; + dateTo?: Date; + dateFromIsDateOnly: boolean; + dateToIsDateOnly: boolean; + } + | McpDateParseFailure; + +const MCP_DATE_FILTER_USAGE = "Use YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS+08:00."; +const DATE_ONLY_INPUT_PATTERN = /^\d{4}-\d{2}-\d{2}$/; + +function isDateOnlyInput(value: unknown) { + return ( + typeof value === "string" && DATE_ONLY_INPUT_PATTERN.test(value.trim()) + ); +} + +export function parseOptionalMcpDate( + name: string, + value?: string, + options: OptionalMcpDateParseOptions = {}, +): OptionalMcpDate { + if (!value) { + return { ok: true, dateOnly: false }; + } + + const parsed = parseDateInput(value); + if (!(parsed instanceof Date)) { + return { + ok: false, + result: jsonToolResult({ + success: false, + message: `Invalid ${name}: "${value}". ${MCP_DATE_FILTER_USAGE}`, + }), + }; + } + + const dateOnly = isDateOnlyInput(value); + return { + ok: true, + value: + dateOnly && options.dateOnlyAsShanghaiStart + ? startOfShanghaiDay(parsed) + : parsed, + dateOnly, + }; +} + +type McpDateRangeInput = { + dateFrom?: string; + dateTo?: string; +}; + +export function parseMcpDateRange({ + dateFrom, + dateTo, +}: McpDateRangeInput): McpDateRange { + const parsedDateFrom = parseOptionalMcpDate("dateFrom", dateFrom); + if (!parsedDateFrom.ok) { + return parsedDateFrom; + } + + const parsedDateTo = parseOptionalMcpDate("dateTo", dateTo); + if (!parsedDateTo.ok) { + return parsedDateTo; + } + + return { + ok: true, + dateFrom: parsedDateFrom.value, + dateTo: parsedDateTo.value, + dateFromIsDateOnly: parsedDateFrom.dateOnly, + dateToIsDateOnly: parsedDateTo.dateOnly, + }; +} + export function getUserId(authInfo?: AuthInfo): string { const userId = authInfo?.extra?.userId; if (typeof userId !== "string" || userId.length === 0) { @@ -116,7 +276,7 @@ export function getUserId(authInfo?: AuthInfo): string { } export async function getViewerInfo(userId: string) { - const user = await prisma.user.findUniqueOrThrow({ + const user = await prisma.user.findUnique({ where: { id: userId }, select: { id: true, @@ -126,6 +286,10 @@ export async function getViewerInfo(userId: string) { }, }); + if (!user) { + throw new Error(`User ${userId} not found`); + } + return user; } diff --git a/src/lib/mcp/tools/bus-tools.ts b/src/lib/mcp/tools/bus-tools.ts index b006d327..f9170af1 100644 --- a/src/lib/mcp/tools/bus-tools.ts +++ b/src/lib/mcp/tools/bus-tools.ts @@ -14,6 +14,7 @@ import { jsonToolResult, mcpLocaleInputSchema, mcpModeInputSchema, + parseOptionalMcpDate, resolveMcpMode, } from "@/lib/mcp/tools/_helpers"; import { summarizeBusDeparture } from "@/lib/mcp/tools/event-summary"; @@ -279,11 +280,16 @@ export function registerBusTools(server: McpServer) { extra, ) => { const resolvedMode = resolveMcpMode(mode); + const parsedAtTime = parseOptionalMcpDate("atTime", atTime); + if (!parsedAtTime.ok) { + return parsedAtTime.result; + } + const result = await getNextBusDepartures({ locale, originCampusId, destinationCampusId, - atTime, + atTime: parsedAtTime.value?.toISOString(), dayType, includeDeparted, limit, diff --git a/src/lib/mcp/tools/calendar-tools.ts b/src/lib/mcp/tools/calendar-tools.ts index e5dcc2af..5a26144c 100644 --- a/src/lib/mcp/tools/calendar-tools.ts +++ b/src/lib/mcp/tools/calendar-tools.ts @@ -22,6 +22,7 @@ import { jsonToolResult, mcpLocaleInputSchema, mcpModeInputSchema, + parseMcpDateRange, resolveMcpMode, sectionCodeSchema, } from "@/lib/mcp/tools/_helpers"; @@ -31,15 +32,6 @@ import { } from "@/lib/mcp/tools/calendar-summary"; import { summarizeCalendarEventCollection } from "@/lib/mcp/tools/event-summary"; import { getPublicOrigin } from "@/lib/site-url"; -import { parseDateInput } from "@/lib/time/parse-date-input"; - -const DATE_ONLY_INPUT_PATTERN = /^\d{4}-\d{2}-\d{2}$/; - -function isDateOnlyInput(value: unknown) { - return ( - typeof value === "string" && DATE_ONLY_INPUT_PATTERN.test(value.trim()) - ); -} function getCalendarSubscriptionReadPayload( subscription: NonNullable< @@ -327,26 +319,16 @@ export function registerCalendarTools(server: McpServer) { }, }, async ({ dateFrom, dateTo, locale, mode }, extra) => { - const parsedDateFrom = dateFrom ? parseDateInput(dateFrom) : undefined; - const parsedDateTo = dateTo ? parseDateInput(dateTo) : undefined; - if (parsedDateFrom === undefined && dateFrom) { - return jsonToolResult({ - success: false, - message: `Invalid dateFrom: "${dateFrom}". Use YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS+08:00.`, - }); - } - if (parsedDateTo === undefined && dateTo) { - return jsonToolResult({ - success: false, - message: `Invalid dateTo: "${dateTo}". Use YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS+08:00.`, - }); + const dateRange = parseMcpDateRange({ dateFrom, dateTo }); + if (!dateRange.ok) { + return dateRange.result; } const events = await listUserCalendarEvents(getUserId(extra.authInfo), { locale, - dateFrom: parsedDateFrom instanceof Date ? parsedDateFrom : undefined, - dateTo: parsedDateTo instanceof Date ? parsedDateTo : undefined, - dateFromIsDateOnly: isDateOnlyInput(dateFrom), - dateToIsDateOnly: isDateOnlyInput(dateTo), + dateFrom: dateRange.dateFrom, + dateTo: dateRange.dateTo, + dateFromIsDateOnly: dateRange.dateFromIsDateOnly, + dateToIsDateOnly: dateRange.dateToIsDateOnly, dateToInclusive: true, }); const resolvedMode = resolveMcpMode(mode); diff --git a/src/lib/mcp/tools/course-tools.ts b/src/lib/mcp/tools/course-tools.ts index ff946e27..7e16fb2a 100644 --- a/src/lib/mcp/tools/course-tools.ts +++ b/src/lib/mcp/tools/course-tools.ts @@ -153,8 +153,10 @@ export function registerCourseTools(server: McpServer) { if (!section) { return jsonToolResult({ + success: false, found: false, message: `Section ${jwId} was not found`, + hint: "Use search_sections to find a valid section jwId, or match_section_codes if you only have a section code.", }); } diff --git a/src/lib/mcp/tools/dashboard-tools.ts b/src/lib/mcp/tools/dashboard-tools.ts index 33ef35c0..74bfdd39 100644 --- a/src/lib/mcp/tools/dashboard-tools.ts +++ b/src/lib/mcp/tools/dashboard-tools.ts @@ -8,35 +8,18 @@ import { jsonToolResult, mcpLocaleInputSchema, mcpModeInputSchema, + parseOptionalMcpDate, resolveMcpMode, } from "@/lib/mcp/tools/_helpers"; import { compactDashboardSnapshot, summarizeDashboardSnapshot, } from "@/lib/mcp/tools/dashboard-summary"; -import { parseDateInput } from "@/lib/time/parse-date-input"; -import { startOfShanghaiDay } from "@/lib/time/shanghai-format"; - -const DATE_ONLY_PATTERN = /^\d{4}-\d{2}-\d{2}$/; function parseOptionalAtTime(atTime: string | undefined) { - if (!atTime) return { ok: true as const, value: undefined }; - const parsed = parseDateInput(atTime); - if (!(parsed instanceof Date)) { - return { - ok: false as const, - result: jsonToolResult({ - success: false, - message: `Invalid atTime: "${atTime}". Use YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS+08:00.`, - }), - }; - } - return { - ok: true as const, - value: DATE_ONLY_PATTERN.test(atTime.trim()) - ? startOfShanghaiDay(parsed) - : parsed, - }; + return parseOptionalMcpDate("atTime", atTime, { + dateOnlyAsShanghaiStart: true, + }); } export function registerDashboardTools(server: McpServer) { diff --git a/src/lib/mcp/tools/event-summary.ts b/src/lib/mcp/tools/event-summary.ts index 4c978f90..7f32e5d4 100644 --- a/src/lib/mcp/tools/event-summary.ts +++ b/src/lib/mcp/tools/event-summary.ts @@ -6,22 +6,8 @@ type BusDeparture = NonNullable< Awaited> >["departures"][number]; -function isRecord(value: unknown): value is Record { - return typeof value === "object" && value !== null && !Array.isArray(value); -} - -function pick< - T extends Record, - const K extends readonly (keyof T)[], ->(value: T, keys: K): Pick { - const out = {} as Pick; - for (const key of keys) { - if (Object.hasOwn(value, key) && value[key] !== undefined) { - out[key] = value[key]; - } - } - return out; -} +import { pick } from "@/lib/mcp/compact-payload"; +import { isRecord } from "@/lib/utils"; export function summarizeSectionCard(value: unknown) { if (!isRecord(value)) return value; diff --git a/src/lib/mcp/tools/my-data-tools.ts b/src/lib/mcp/tools/my-data-tools.ts index 2201303c..3ef41d0c 100644 --- a/src/lib/mcp/tools/my-data-tools.ts +++ b/src/lib/mcp/tools/my-data-tools.ts @@ -17,6 +17,8 @@ import { jsonToolResult, mcpLocaleInputSchema, mcpModeInputSchema, + parseMcpDateRange, + parseOptionalMcpDate, resolveMcpMode, } from "@/lib/mcp/tools/_helpers"; import { @@ -25,7 +27,6 @@ import { summarizeHomeworkCard, summarizeTodoCard, } from "@/lib/mcp/tools/event-summary"; -import { parseDateInput } from "@/lib/time/parse-date-input"; import { toShanghaiIsoString } from "@/lib/time/serialize-date-output"; export function registerMyDataTools(server: McpServer) { @@ -63,7 +64,7 @@ export function registerMyDataTools(server: McpServer) { "set_my_homework_completion", { description: - "Mark a homework as completed or incomplete. Prefer this over unset_my_homework_completion — pass completed: false to revert.", + "Mark a homework as completed or incomplete. Pass completed: false to revert to incomplete.", inputSchema: { homeworkId: z.string().trim().min(1), completed: z.boolean(), @@ -82,6 +83,7 @@ export function registerMyDataTools(server: McpServer) { return jsonToolResult({ success: false, message: "Homework not found", + hint: "Use list_my_homeworks or list_homeworks_by_section to confirm the homeworkId before updating completion.", }); } @@ -123,49 +125,6 @@ export function registerMyDataTools(server: McpServer) { }, ); - server.registerTool( - "unset_my_homework_completion", - { - description: - "Revert a completed homework back to incomplete. Equivalent to set_my_homework_completion(completed: false).", - inputSchema: { - homeworkId: z.string().trim().min(1), - mode: mcpModeInputSchema, - }, - }, - async ({ homeworkId, mode }, extra) => { - const resolvedMode = resolveMcpMode(mode); - const userId = getUserId(extra.authInfo); - const homework = await prisma.homework.findUnique({ - where: { id: homeworkId }, - select: { id: true, deletedAt: true }, - }); - - if (!homework || homework.deletedAt) { - return jsonToolResult({ - success: false, - message: "Homework not found", - }); - } - - await prisma.homeworkCompletion.deleteMany({ - where: { userId, homeworkId }, - }); - - return jsonToolResult( - { - success: true, - completion: { - homeworkId, - completed: false, - completedAt: null, - }, - }, - { mode: resolvedMode }, - ); - }, - ); - server.registerTool( "list_my_schedules", { @@ -183,24 +142,14 @@ export function registerMyDataTools(server: McpServer) { async ({ dateFrom, dateTo, weekday, limit, locale, mode }, extra) => { const resolvedMode = resolveMcpMode(mode); const userId = getUserId(extra.authInfo); - const parsedDateFrom = dateFrom ? parseDateInput(dateFrom) : undefined; - if (parsedDateFrom === undefined && dateFrom) { - return jsonToolResult({ - success: false, - message: `Invalid dateFrom: "${dateFrom}". Use YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS+08:00.`, - }); - } - const parsedDateTo = dateTo ? parseDateInput(dateTo) : undefined; - if (parsedDateTo === undefined && dateTo) { - return jsonToolResult({ - success: false, - message: `Invalid dateTo: "${dateTo}". Use YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS+08:00.`, - }); + const dateRange = parseMcpDateRange({ dateFrom, dateTo }); + if (!dateRange.ok) { + return dateRange.result; } const schedules = await listSubscribedSchedules(userId, { locale, - dateFrom: parsedDateFrom instanceof Date ? parsedDateFrom : undefined, - dateTo: parsedDateTo instanceof Date ? parsedDateTo : undefined, + dateFrom: dateRange.dateFrom, + dateTo: dateRange.dateTo, weekday, limit, }); @@ -229,24 +178,14 @@ export function registerMyDataTools(server: McpServer) { ) => { const resolvedMode = resolveMcpMode(mode); const userId = getUserId(extra.authInfo); - const parsedDateFrom = dateFrom ? parseDateInput(dateFrom) : undefined; - if (parsedDateFrom === undefined && dateFrom) { - return jsonToolResult({ - success: false, - message: `Invalid dateFrom: "${dateFrom}". Use YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS+08:00.`, - }); - } - const parsedDateTo = dateTo ? parseDateInput(dateTo) : undefined; - if (parsedDateTo === undefined && dateTo) { - return jsonToolResult({ - success: false, - message: `Invalid dateTo: "${dateTo}". Use YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS+08:00.`, - }); + const dateRange = parseMcpDateRange({ dateFrom, dateTo }); + if (!dateRange.ok) { + return dateRange.result; } const exams = await listSubscribedExams(userId, { locale, - dateFrom: parsedDateFrom instanceof Date ? parsedDateFrom : undefined, - dateTo: parsedDateTo instanceof Date ? parsedDateTo : undefined, + dateFrom: dateRange.dateFrom, + dateTo: dateRange.dateTo, includeDateUnknown, limit, }); @@ -276,58 +215,50 @@ export function registerMyDataTools(server: McpServer) { const userId = getUserId(extra.authInfo); const user = await getViewerInfo(userId); const sectionIds = await getSubscribedSectionIds(userId); - const atTimeDate = atTime - ? (parseDateInput(atTime) ?? undefined) - : undefined; - if (atTime && !(atTimeDate instanceof Date)) { - return jsonToolResult({ - success: false, - message: `Invalid atTime: "${atTime}". Use YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS+08:00.`, - }); + const atTimeDate = parseOptionalMcpDate("atTime", atTime); + if (!atTimeDate.ok) { + return atTimeDate.result; } const { now, todayStart, tomorrowStart } = getTodayBounds( - atTimeDate instanceof Date ? atTimeDate : undefined, + atTimeDate.value, ); - // Prisma's pg adapter currently emits a driver deprecation warning when - // a single request fans out multiple queries concurrently. - const pendingTodosCount = await prisma.todo.count({ - where: { - userId, - completed: false, - }, - }); - const pendingHomeworksCount = + // Run all count queries concurrently. The pg adapter handles + // concurrent queries correctly; the previous sequential pattern + // added unnecessary latency to the overview endpoint. + const [ + pendingTodosCount, + pendingHomeworksCount, + todaySchedulesCount, + upcomingExamsCount, + ] = await Promise.all([ + prisma.todo.count({ where: { userId, completed: false } }), sectionIds.length > 0 - ? await prisma.homework.count({ + ? prisma.homework.count({ where: { deletedAt: null, sectionId: { in: sectionIds }, homeworkCompletions: { none: { userId } }, }, }) - : 0; - const todaySchedulesCount = + : Promise.resolve(0), sectionIds.length > 0 - ? await prisma.schedule.count({ + ? prisma.schedule.count({ where: { sectionId: { in: sectionIds }, - date: { - gte: todayStart, - lt: tomorrowStart, - }, + date: { gte: todayStart, lt: tomorrowStart }, }, }) - : 0; - const upcomingExamsCount = + : Promise.resolve(0), sectionIds.length > 0 - ? await prisma.exam.count({ + ? prisma.exam.count({ where: { sectionId: { in: sectionIds }, examDate: { gte: todayStart }, }, }) - : 0; + : Promise.resolve(0), + ]); const dueTodos = await prisma.todo.findMany({ where: { userId, @@ -439,18 +370,11 @@ export function registerMyDataTools(server: McpServer) { async ({ locale, atTime, mode }, extra) => { const resolvedMode = resolveMcpMode(mode); const userId = getUserId(extra.authInfo); - const atTimeDate = atTime - ? (parseDateInput(atTime) ?? undefined) - : undefined; - if (atTime && !(atTimeDate instanceof Date)) { - return jsonToolResult({ - success: false, - message: `Invalid atTime: "${atTime}". Use YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS+08:00.`, - }); + const atTimeDate = parseOptionalMcpDate("atTime", atTime); + if (!atTimeDate.ok) { + return atTimeDate.result; } - const { todayStart } = getTodayBounds( - atTimeDate instanceof Date ? atTimeDate : undefined, - ); + const { todayStart } = getTodayBounds(atTimeDate.value); const windowEnd = new Date(todayStart); windowEnd.setDate(windowEnd.getDate() + 7); const events = await listUserCalendarEvents(userId, { diff --git a/src/lib/mcp/tools/profile-tools.ts b/src/lib/mcp/tools/profile-tools.ts index c3b5b1d2..60920602 100644 --- a/src/lib/mcp/tools/profile-tools.ts +++ b/src/lib/mcp/tools/profile-tools.ts @@ -1,14 +1,15 @@ import type { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import * as z from "zod"; +import type { Prisma } from "@/generated/prisma/client"; import { prisma } from "@/lib/db/prisma"; import { getUserId, jsonToolResult, mcpModeInputSchema, + parseOptionalFieldDate, resolveMcpMode, todoPrioritySchema, } from "@/lib/mcp/tools/_helpers"; -import { parseDateInput } from "@/lib/time/parse-date-input"; export function registerProfileTools(server: McpServer) { server.registerTool( @@ -22,7 +23,7 @@ export function registerProfileTools(server: McpServer) { }, async ({ mode }, extra) => { const userId = getUserId(extra.authInfo); - const user = await prisma.user.findUniqueOrThrow({ + const user = await prisma.user.findUnique({ where: { id: userId }, select: { id: true, @@ -35,6 +36,13 @@ export function registerProfileTools(server: McpServer) { }, }); + if (!user) { + return jsonToolResult({ + success: false, + message: "User not found", + }); + } + return jsonToolResult(user, { mode: resolveMcpMode(mode), }); @@ -122,12 +130,9 @@ export function registerProfileTools(server: McpServer) { }, async ({ title, content, priority, dueAt, mode }, extra) => { const userId = getUserId(extra.authInfo); - const parsedDueAt = parseDateInput(dueAt); - if (parsedDueAt === undefined) { - return jsonToolResult({ - success: false, - message: "Invalid due date", - }); + const parsedDueAt = parseOptionalFieldDate("dueAt", dueAt); + if (!parsedDueAt.ok) { + return parsedDueAt.result; } const todo = await prisma.todo.create({ @@ -137,7 +142,7 @@ export function registerProfileTools(server: McpServer) { title, content: content?.trim() || null, priority, - dueAt: parsedDueAt, + dueAt: parsedDueAt.value, }, }); @@ -190,19 +195,16 @@ export function registerProfileTools(server: McpServer) { } const hasDueAt = dueAt !== undefined; - const parsedDueAt = hasDueAt ? parseDateInput(dueAt) : undefined; - if (hasDueAt && parsedDueAt === undefined) { - return jsonToolResult({ - success: false, - message: "Invalid due date", - }); + const parsedDueAt = parseOptionalFieldDate("dueAt", dueAt, hasDueAt); + if (!parsedDueAt.ok) { + return parsedDueAt.result; } - const updates: Record = {}; + const updates: Prisma.TodoUpdateInput = {}; if (title !== undefined) updates.title = title; if (content !== undefined) updates.content = content?.trim() || null; if (priority !== undefined) updates.priority = priority; - if (hasDueAt) updates.dueAt = parsedDueAt; + if (hasDueAt) updates.dueAt = parsedDueAt.value; if (completed !== undefined) updates.completed = completed; if (Object.keys(updates).length === 0) { diff --git a/src/lib/mcp/tools/section-data/homework-tools.ts b/src/lib/mcp/tools/section-data/homework-tools.ts index 3014f242..b455f125 100644 --- a/src/lib/mcp/tools/section-data/homework-tools.ts +++ b/src/lib/mcp/tools/section-data/homework-tools.ts @@ -1,6 +1,7 @@ import type { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import * as z from "zod"; import { withHomeworkItemState } from "@/features/homeworks/server/homework-item-state"; +import type { Prisma } from "@/generated/prisma/client"; import { DEFAULT_LOCALE } from "@/i18n/config"; import { findActiveSuspension } from "@/lib/auth/viewer-context"; import { getPrisma, prisma } from "@/lib/db/prisma"; @@ -9,11 +10,11 @@ import { jsonToolResult, mcpLocaleInputSchema, mcpModeInputSchema, + parseOptionalFieldDate, resolveMcpMode, resolveSectionByJwId, } from "@/lib/mcp/tools/_helpers"; import { summarizeHomeworkCard } from "@/lib/mcp/tools/event-summary"; -import { parseDateInput } from "@/lib/time/parse-date-input"; import { sectionNotFoundToolResult } from "./shared"; const homeworkToolUserSelect = { @@ -60,6 +61,24 @@ async function getHomeworkItemById( return homeworkItem ?? null; } +function invalidSubmissionWindow( + submissionStartAt: Date | null | undefined, + submissionDueAt: Date | null | undefined, +) { + if ( + submissionStartAt && + submissionDueAt && + submissionStartAt.getTime() > submissionDueAt.getTime() + ) { + return jsonToolResult( + { success: false, message: "Submission start must be before due" }, + { mode: "default" }, + ); + } + + return null; +} + export function registerSectionHomeworkTools(server: McpServer) { server.registerTool( "list_homeworks_by_section", @@ -185,42 +204,36 @@ export function registerSectionHomeworkTools(server: McpServer) { const { section } = await resolveSectionByJwId(sectionJwId, locale); if (!section) { - return jsonToolResult( - { success: false, message: `Section ${sectionJwId} was not found` }, - { mode: resolvedMode }, - ); + return sectionNotFoundToolResult(sectionJwId, resolvedMode); } - const parsedPublishedAt = parseDateInput(publishedAt); - const parsedSubmissionStartAt = parseDateInput(submissionStartAt); - const parsedSubmissionDueAt = parseDateInput(submissionDueAt); - if (parsedPublishedAt === undefined) { - return jsonToolResult( - { success: false, message: "Invalid publish date" }, - { mode: resolvedMode }, - ); + const parsedPublishedAt = parseOptionalFieldDate( + "publishedAt", + publishedAt, + ); + if (!parsedPublishedAt.ok) { + return parsedPublishedAt.result; } - if (parsedSubmissionStartAt === undefined) { - return jsonToolResult( - { success: false, message: "Invalid submission start" }, - { mode: resolvedMode }, - ); + const parsedSubmissionStartAt = parseOptionalFieldDate( + "submissionStartAt", + submissionStartAt, + ); + if (!parsedSubmissionStartAt.ok) { + return parsedSubmissionStartAt.result; } - if (parsedSubmissionDueAt === undefined) { - return jsonToolResult( - { success: false, message: "Invalid submission due" }, - { mode: resolvedMode }, - ); + const parsedSubmissionDueAt = parseOptionalFieldDate( + "submissionDueAt", + submissionDueAt, + ); + if (!parsedSubmissionDueAt.ok) { + return parsedSubmissionDueAt.result; } - if ( - parsedSubmissionStartAt && - parsedSubmissionDueAt && - parsedSubmissionStartAt.getTime() > parsedSubmissionDueAt.getTime() - ) { - return jsonToolResult( - { success: false, message: "Submission start must be before due" }, - { mode: resolvedMode }, - ); + const submissionWindowError = invalidSubmissionWindow( + parsedSubmissionStartAt.value, + parsedSubmissionDueAt.value, + ); + if (submissionWindowError) { + return submissionWindowError; } const trimmedDescription = (description ?? "").trim(); @@ -231,9 +244,9 @@ export function registerSectionHomeworkTools(server: McpServer) { title, isMajor: isMajor === true, requiresTeam: requiresTeam === true, - publishedAt: parsedPublishedAt, - submissionStartAt: parsedSubmissionStartAt, - submissionDueAt: parsedSubmissionDueAt, + publishedAt: parsedPublishedAt.value, + submissionStartAt: parsedSubmissionStartAt.value, + submissionDueAt: parsedSubmissionDueAt.value, createdById: userId, updatedById: userId, }, @@ -332,43 +345,36 @@ export function registerSectionHomeworkTools(server: McpServer) { const hasSubmissionStartAt = submissionStartAt !== undefined; const hasSubmissionDueAt = submissionDueAt !== undefined; - const parsedPublishedAt = hasPublishedAt - ? parseDateInput(publishedAt) - : undefined; - const parsedSubmissionStartAt = hasSubmissionStartAt - ? parseDateInput(submissionStartAt) - : undefined; - const parsedSubmissionDueAt = hasSubmissionDueAt - ? parseDateInput(submissionDueAt) - : undefined; - - if (hasPublishedAt && parsedPublishedAt === undefined) { - return jsonToolResult( - { success: false, message: "Invalid publish date" }, - { mode: resolvedMode }, - ); + const parsedPublishedAt = parseOptionalFieldDate( + "publishedAt", + publishedAt, + hasPublishedAt, + ); + if (!parsedPublishedAt.ok) { + return parsedPublishedAt.result; } - if (hasSubmissionStartAt && parsedSubmissionStartAt === undefined) { - return jsonToolResult( - { success: false, message: "Invalid submission start" }, - { mode: resolvedMode }, - ); + const parsedSubmissionStartAt = parseOptionalFieldDate( + "submissionStartAt", + submissionStartAt, + hasSubmissionStartAt, + ); + if (!parsedSubmissionStartAt.ok) { + return parsedSubmissionStartAt.result; } - if (hasSubmissionDueAt && parsedSubmissionDueAt === undefined) { - return jsonToolResult( - { success: false, message: "Invalid submission due" }, - { mode: resolvedMode }, - ); + const parsedSubmissionDueAt = parseOptionalFieldDate( + "submissionDueAt", + submissionDueAt, + hasSubmissionDueAt, + ); + if (!parsedSubmissionDueAt.ok) { + return parsedSubmissionDueAt.result; } - if ( - parsedSubmissionStartAt && - parsedSubmissionDueAt && - parsedSubmissionStartAt.getTime() > parsedSubmissionDueAt.getTime() - ) { - return jsonToolResult( - { success: false, message: "Submission start must be before due" }, - { mode: resolvedMode }, - ); + const submissionWindowError = invalidSubmissionWindow( + parsedSubmissionStartAt.value, + parsedSubmissionDueAt.value, + ); + if (submissionWindowError) { + return submissionWindowError; } const existing = await prisma.homework.findUnique({ @@ -377,7 +383,11 @@ export function registerSectionHomeworkTools(server: McpServer) { }); if (!existing) { return jsonToolResult( - { success: false, message: "Homework not found" }, + { + success: false, + message: "Homework not found", + hint: "Use list_homeworks_by_section or list_my_homeworks to confirm the homeworkId before updating it.", + }, { mode: resolvedMode }, ); } @@ -388,17 +398,18 @@ export function registerSectionHomeworkTools(server: McpServer) { ); } - const updates: Record = { updatedById: userId }; + const updates: Prisma.HomeworkUncheckedUpdateInput = { + updatedById: userId, + }; if (title !== undefined) updates.title = title; if (isMajor !== undefined) updates.isMajor = isMajor === true; if (requiresTeam !== undefined) updates.requiresTeam = requiresTeam === true; - if (parsedPublishedAt !== undefined) - updates.publishedAt = parsedPublishedAt; - if (parsedSubmissionStartAt !== undefined) - updates.submissionStartAt = parsedSubmissionStartAt; - if (parsedSubmissionDueAt !== undefined) - updates.submissionDueAt = parsedSubmissionDueAt; + if (hasPublishedAt) updates.publishedAt = parsedPublishedAt.value; + if (hasSubmissionStartAt) + updates.submissionStartAt = parsedSubmissionStartAt.value; + if (hasSubmissionDueAt) + updates.submissionDueAt = parsedSubmissionDueAt.value; const wantsDescription = description !== undefined; const trimmedDescription = (description ?? "").trim(); diff --git a/src/lib/mcp/tools/section-data/record-tools.ts b/src/lib/mcp/tools/section-data/record-tools.ts index 97f7419f..aef4f094 100644 --- a/src/lib/mcp/tools/section-data/record-tools.ts +++ b/src/lib/mcp/tools/section-data/record-tools.ts @@ -7,6 +7,7 @@ import { jsonToolResult, mcpLocaleInputSchema, mcpModeInputSchema, + parseMcpDateRange, resolveMcpMode, resolveSectionByJwId, } from "@/lib/mcp/tools/_helpers"; @@ -15,7 +16,6 @@ import { buildScheduleListWhere, publicScheduleInclude, } from "@/lib/schedule-queries"; -import { parseDateInput } from "@/lib/time/parse-date-input"; import { sectionExamInclude, sectionNotFoundToolResult, @@ -68,19 +68,9 @@ export function registerSectionRecordTools(server: McpServer) { }) => { const localizedPrisma = getPrisma(locale); const pagination = normalizePagination({ page, pageSize: limit }); - const parsedDateFrom = dateFrom ? parseDateInput(dateFrom) : undefined; - if (!(parsedDateFrom instanceof Date) && dateFrom) { - return jsonToolResult({ - success: false, - message: `Invalid dateFrom: "${dateFrom}". Use YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS+08:00.`, - }); - } - const parsedDateTo = dateTo ? parseDateInput(dateTo) : undefined; - if (!(parsedDateTo instanceof Date) && dateTo) { - return jsonToolResult({ - success: false, - message: `Invalid dateTo: "${dateTo}". Use YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS+08:00.`, - }); + const dateRange = parseMcpDateRange({ dateFrom, dateTo }); + if (!dateRange.ok) { + return dateRange.result; } const where = buildScheduleListWhere({ sectionId, @@ -91,8 +81,8 @@ export function registerSectionRecordTools(server: McpServer) { roomId, roomJwId, weekday, - dateFrom: parsedDateFrom instanceof Date ? parsedDateFrom : undefined, - dateTo: parsedDateTo instanceof Date ? parsedDateTo : undefined, + dateFrom: dateRange.dateFrom, + dateTo: dateRange.dateTo, }); const [schedules, total] = await Promise.all([ @@ -153,33 +143,17 @@ export function registerSectionRecordTools(server: McpServer) { return sectionNotFoundToolResult(sectionJwId, mode); } - const parsedDateFrom = dateFrom ? parseDateInput(dateFrom) : undefined; - if (parsedDateFrom === undefined && dateFrom) { - return jsonToolResult({ - found: true, - section, - schedules: [], - message: `Invalid dateFrom: "${dateFrom}". Use YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS+08:00.`, - }); - } - const parsedDateTo = dateTo ? parseDateInput(dateTo) : undefined; - if (parsedDateTo === undefined && dateTo) { - return jsonToolResult({ - found: true, - section, - schedules: [], - message: `Invalid dateTo: "${dateTo}". Use YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS+08:00.`, - }); + const dateRange = parseMcpDateRange({ dateFrom, dateTo }); + if (!dateRange.ok) { + return dateRange.result; } const dateFilter = - parsedDateFrom instanceof Date || parsedDateTo instanceof Date + dateRange.dateFrom || dateRange.dateTo ? { date: { - ...(parsedDateFrom instanceof Date - ? { gte: parsedDateFrom } - : {}), - ...(parsedDateTo instanceof Date ? { lte: parsedDateTo } : {}), + ...(dateRange.dateFrom ? { gte: dateRange.dateFrom } : {}), + ...(dateRange.dateTo ? { lte: dateRange.dateTo } : {}), }, } : {}; diff --git a/src/lib/mcp/tools/section-data/shared.ts b/src/lib/mcp/tools/section-data/shared.ts index c7ac0126..d8d90a43 100644 --- a/src/lib/mcp/tools/section-data/shared.ts +++ b/src/lib/mcp/tools/section-data/shared.ts @@ -66,9 +66,11 @@ export function sectionNotFoundToolResult( ) { return jsonToolResult( { + success: false, found: false, message: `Section ${sectionJwId} was not found`, + hint: "Use search_sections to find a valid section jwId, or match_section_codes if you only have a section code.", }, - mode === undefined ? undefined : { mode: resolveMcpMode(mode) }, + { mode: resolveMcpMode(mode) }, ); } diff --git a/src/lib/mcp/urls.ts b/src/lib/mcp/urls.ts index 8eff16e8..d5aa37d6 100644 --- a/src/lib/mcp/urls.ts +++ b/src/lib/mcp/urls.ts @@ -1,11 +1,6 @@ import { getBetterAuthBaseUrl, getPublicOrigin } from "@/lib/site-url"; -export const MCP_ROUTE_PATH = "/api/mcp"; -export const OAUTH_AUTHORIZATION_PATH = "/api/auth/oauth2/authorize"; -export const OAUTH_REGISTRATION_PATH = "/api/auth/oauth2/register"; -export const OAUTH_TOKEN_PATH = "/api/auth/oauth2/token"; -export const OAUTH_OPENID_CONFIGURATION_PATH = - "/.well-known/openid-configuration"; +const MCP_ROUTE_PATH = "/api/mcp"; function uniqueUrls(values: string[]): string[] { return [...new Set(values.map((value) => value.replace(/\/$/, "")))]; @@ -18,8 +13,12 @@ function normalizePathname(pathname: string): string { return pathname.endsWith("/") ? pathname.slice(0, -1) : pathname; } +function toUrl(target: URL | string): URL { + return new URL(target.toString()); +} + function insertWellKnownPath(target: URL | string, suffix: string): URL { - const url = new URL(target.toString()); + const url = toUrl(target); const normalizedPathname = normalizePathname(url.pathname); return new URL( `/.well-known/${suffix}${normalizedPathname}${url.search}`, @@ -28,7 +27,7 @@ function insertWellKnownPath(target: URL | string, suffix: string): URL { } function appendWellKnownPath(target: URL | string, suffix: string): URL { - const url = new URL(target.toString()); + const url = toUrl(target); const normalizedPathname = normalizePathname(url.pathname); return new URL( `${normalizedPathname}/.well-known/${suffix}${url.search}`, @@ -49,11 +48,11 @@ export function getCanonicalOAuthIssuer(): string { } export function getOAuthTokenVerificationIssuers(): string[] { - return uniqueUrls([getCanonicalOAuthIssuer(), getPublicOrigin()]); + return [getCanonicalOAuthIssuer()]; } export function getOAuthRestAudienceUrls(): string[] { - return uniqueUrls([getPublicOrigin(), getCanonicalOAuthIssuer()]); + return [getCanonicalOAuthIssuer()]; } export function getOAuthMcpAudienceUrls(): string[] { @@ -73,30 +72,22 @@ export function getJwksUrlForOAuthVerification(): string { return new URL("/api/auth/jwks", `${getCanonicalOAuthIssuer()}/`).toString(); } -export function getMcpServerUrl(_request?: Request): URL { +export function getMcpServerUrl(): URL { return new URL(getOAuthMcpResourceUrl()); } -export function getOAuthIssuerUrl(_request?: Request): URL { +export function getOAuthIssuerUrl(): URL { return new URL(getCanonicalOAuthIssuer()); } -export function getOAuthAuthorizationServerMetadataUrl( - _request?: Request, -): URL { +export function getOAuthAuthorizationServerMetadataUrl(): URL { return insertWellKnownPath(getOAuthIssuerUrl(), "oauth-authorization-server"); } -export function getOAuthProtectedResourceMetadataUrl(_request?: Request): URL { +export function getOAuthProtectedResourceMetadataUrl(): URL { return insertWellKnownPath(getMcpServerUrl(), "oauth-protected-resource"); } -export function getOAuthOpenIdConfigurationUrl(_request?: Request): URL { +export function getOAuthOpenIdConfigurationUrl(): URL { return appendWellKnownPath(getOAuthIssuerUrl(), "openid-configuration"); } - -export function getOAuthOpenIdConfigurationCompatibilityUrl( - _request?: Request, -): URL { - return insertWellKnownPath(getOAuthIssuerUrl(), "openid-configuration"); -} diff --git a/tests/e2e/src/app/api/mcp/test.ts b/tests/e2e/src/app/api/mcp/test.ts index a707b390..7892aadb 100644 --- a/tests/e2e/src/app/api/mcp/test.ts +++ b/tests/e2e/src/app/api/mcp/test.ts @@ -21,6 +21,13 @@ import { createHash } from "node:crypto"; import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; import { expect, type Page, test } from "@playwright/test"; +import { + DEFAULT_OAUTH_CLIENT_SCOPES, + MCP_TOOLS_SCOPE, + OAUTH_AUTHORIZATION_CODE_GRANT_TYPE, + OAUTH_CODE_RESPONSE_TYPE, + OAUTH_PUBLIC_CLIENT_AUTH_METHOD, +} from "@/lib/oauth/constants"; import { signInAsDebugUser } from "../../../../utils/auth"; import { DEV_SEED, DEV_SEED_ANCHOR } from "../../../../utils/dev-seed"; import { @@ -37,6 +44,9 @@ function generateCodeChallenge(codeVerifier: string) { } const REDIRECT_URI = `${PLAYWRIGHT_BASE_URL}/e2e/oauth/callback`; +const MCP_CLIENT_SCOPES = [...DEFAULT_OAUTH_CLIENT_SCOPES, MCP_TOOLS_SCOPE]; +const MCP_CLIENT_SCOPE = MCP_CLIENT_SCOPES.join(" "); +const DEFAULT_CLIENT_SCOPE = DEFAULT_OAUTH_CLIENT_SCOPES.join(" "); const TRUSTED_BROWSER_ORIGIN = PLAYWRIGHT_BASE_URL.includes("127.0.0.1") ? PLAYWRIGHT_BASE_URL.replace("127.0.0.1", "localhost") : PLAYWRIGHT_BASE_URL.replace("localhost", "127.0.0.1"); @@ -78,9 +88,9 @@ async function registerPublicClient(request: Page["request"], scope: string) { data: { client_name: `mcp-e2e-${Date.now()}`, redirect_uris: [REDIRECT_URI], - token_endpoint_auth_method: "none", - grant_types: ["authorization_code"], - response_types: ["code"], + token_endpoint_auth_method: OAUTH_PUBLIC_CLIENT_AUTH_METHOD, + grant_types: [OAUTH_AUTHORIZATION_CODE_GRANT_TYPE], + response_types: [OAUTH_CODE_RESPONSE_TYPE], scope, }, }); @@ -103,7 +113,7 @@ async function authorizeAndGetCode( "/api/auth/oauth2/authorize", { params: { - response_type: "code", + response_type: OAUTH_CODE_RESPONSE_TYPE, client_id: clientId, redirect_uri: REDIRECT_URI, scope: options.scope, @@ -168,7 +178,7 @@ async function issueAccessToken( const tokenResponse = await request.post("/api/auth/oauth2/token", { form: { - grant_type: "authorization_code", + grant_type: OAUTH_AUTHORIZATION_CODE_GRANT_TYPE, client_id: clientId, code, code_verifier: codeVerifier, @@ -325,8 +335,8 @@ test.describe("/api/mcp – MCP Streamable-HTTP transport", () => { const resource = `${PLAYWRIGHT_BASE_URL}/api/mcp`; await signInAsDebugUser(page, "/"); const { accessToken } = await issueAccessToken(page, request, { - scope: "openid profile mcp:tools", - clientScopes: ["openid", "profile", "mcp:tools"], + scope: MCP_CLIENT_SCOPE, + clientScopes: MCP_CLIENT_SCOPES, resource, }); @@ -362,8 +372,8 @@ test.describe("/api/mcp – MCP Streamable-HTTP transport", () => { await signInAsDebugUser(page, "/"); const { accessToken } = await issueAccessToken(page, request, { - scope: "openid profile mcp:tools", - clientScopes: ["openid", "profile", "mcp:tools"], + scope: MCP_CLIENT_SCOPE, + clientScopes: MCP_CLIENT_SCOPES, resource, }); @@ -404,8 +414,8 @@ test.describe("/api/mcp – MCP Streamable-HTTP transport", () => { await signInAsDebugUser(page, "/"); const { accessToken } = await issueAccessToken(page, request, { - scope: "openid profile", - clientScopes: ["openid", "profile"], + scope: DEFAULT_CLIENT_SCOPE, + clientScopes: [...DEFAULT_OAUTH_CLIENT_SCOPES], resource, }); @@ -434,7 +444,7 @@ test.describe("/api/mcp – MCP Streamable-HTTP transport", () => { 'error="insufficient_scope"', ); expect(response.headers()["www-authenticate"]).toContain( - 'scope="mcp:tools"', + `scope="${MCP_TOOLS_SCOPE}"`, ); await expect(response.json()).resolves.toEqual({ error: "insufficient_scope", @@ -449,8 +459,8 @@ test.describe("/api/mcp – MCP Streamable-HTTP transport", () => { await signInAsDebugUser(page, "/"); const { accessToken } = await issueAccessToken(page, request, { - scope: "openid profile mcp:tools", - clientScopes: ["openid", "profile", "mcp:tools"], + scope: MCP_CLIENT_SCOPE, + clientScopes: MCP_CLIENT_SCOPES, resource, includeResourceInTokenExchange: false, }); @@ -496,8 +506,8 @@ test.describe("/api/mcp – MCP Streamable-HTTP transport", () => { await signInAsDebugUser(page, "/"); const { accessToken } = await issueAccessToken(page, request, { - scope: "openid profile mcp:tools", - clientScopes: ["openid", "profile", "mcp:tools"], + scope: MCP_CLIENT_SCOPE, + clientScopes: MCP_CLIENT_SCOPES, resource, includeResourceInTokenExchange: false, }); @@ -529,8 +539,8 @@ test.describe("/api/mcp – MCP Streamable-HTTP transport", () => { }, }); const { accessToken } = await issueAccessToken(page, request, { - scope: "openid profile mcp:tools", - clientScopes: ["openid", "profile", "mcp:tools"], + scope: MCP_CLIENT_SCOPE, + clientScopes: MCP_CLIENT_SCOPES, resource, }); @@ -600,7 +610,6 @@ test.describe("/api/mcp – MCP Streamable-HTTP transport", () => { "get_upcoming_deadlines", "list_my_homeworks", "set_my_homework_completion", - "unset_my_homework_completion", "list_my_schedules", "list_my_exams", "get_my_overview", @@ -913,37 +922,6 @@ test.describe("/api/mcp – MCP Streamable-HTTP transport", () => { expect(setCompletionFalsePayload.success).toBe(true); expect(setCompletionFalsePayload.completion?.completed).toBe(false); - // Re-mark as completed so unset_my_homework_completion has something to undo - const reSetCompletionResult = await mcpClient.callTool({ - name: "set_my_homework_completion", - arguments: { - homeworkId: firstHomeworkId, - completed: true, - }, - }); - expect( - ( - parseTextContent(reSetCompletionResult) as { - success?: boolean; - } - ).success, - ).toBe(true); - - const unsetCompletionResult = await mcpClient.callTool({ - name: "unset_my_homework_completion", - arguments: { - homeworkId: firstHomeworkId, - }, - }); - const unsetCompletionPayload = parseTextContent( - unsetCompletionResult, - ) as { - success?: boolean; - completion?: { completed?: boolean; completedAt?: null }; - }; - expect(unsetCompletionPayload.success).toBe(true); - expect(unsetCompletionPayload.completion?.completed).toBe(false); - const mySchedulesResult = await mcpClient.callTool({ name: "list_my_schedules", arguments: { @@ -1307,6 +1285,17 @@ test.describe("/api/mcp – MCP Streamable-HTTP transport", () => { } | undefined; await expect(async () => { + const preferenceResponse = await page.request.post( + "/api/bus/preferences", + { + data: { + preferredOriginCampusId: 1, + preferredDestinationCampusId: 4, + showDepartedTrips: true, + }, + }, + ); + expect(preferenceResponse.status()).toBe(200); busResult = await mcpClient.callTool({ name: "query_bus_timetable", arguments: { diff --git a/tests/integration/mcp-tools.test.ts b/tests/integration/mcp-tools.test.ts index da832e3d..d476fc3b 100644 --- a/tests/integration/mcp-tools.test.ts +++ b/tests/integration/mcp-tools.test.ts @@ -669,17 +669,15 @@ describe("list_schedules_by_section — date range filter", () => { it("returns error message for invalid dateFrom", async () => { const result = await mcp.call<{ - found?: boolean; + success?: boolean; message?: string; - schedules?: unknown[]; }>("list_schedules_by_section", { sectionJwId: DEV_SEED.section.jwId, dateFrom: "yesterday", locale: "zh-cn", }); - expect(result.found).toBe(true); - expect(result.schedules).toHaveLength(0); + expect(result.success).toBe(false); expect(result.message).toContain("yesterday"); }); }); @@ -721,6 +719,23 @@ describe("query_schedules — flexible date filters", () => { }); }); +describe("course and section lookup errors", () => { + it("get_section_by_jw_id returns a recovery hint when the jwId is missing", async () => { + const result = await mcp.call<{ + found?: boolean; + message?: string; + hint?: string; + }>("get_section_by_jw_id", { + jwId: 999999999, + locale: "zh-cn", + }); + + expect(result.found).toBe(false); + expect(result.message).toContain("999999999"); + expect(result.hint).toContain("search_sections"); + }); +}); + // --------------------------------------------------------------------------- // Dashboard snapshot — compact shape verification // --------------------------------------------------------------------------- @@ -805,6 +820,24 @@ describe("get_next_buses — default mode drops repeated campus objects", () => expect(result.totalRoutes).toBeGreaterThan(0); }); + it("rejects invalid atTime with the shared MCP date message", async () => { + const result = await mcp.call<{ success?: boolean; message?: string }>( + "get_next_buses", + { + locale: "zh-cn", + originCampusId: DEV_SEED.bus.originCampusId, + destinationCampusId: DEV_SEED.bus.destinationCampusId, + atTime: "not-a-date", + }, + ); + + expect(result).toMatchObject({ + success: false, + message: + 'Invalid atTime: "not-a-date". Use YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS+08:00.', + }); + }); + it("departure items omit originCampus and destinationCampus", async () => { const result = await mcp.call<{ originCampus?: { id?: number }; diff --git a/tests/integration/utils/mcp-harness.ts b/tests/integration/utils/mcp-harness.ts index 6eb46b4d..549db512 100644 --- a/tests/integration/utils/mcp-harness.ts +++ b/tests/integration/utils/mcp-harness.ts @@ -28,6 +28,10 @@ import type { } from "@modelcontextprotocol/sdk/shared/transport.js"; import type { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; import { createMcpServer } from "@/lib/mcp/server"; +import { + DEFAULT_OAUTH_CLIENT_SCOPES, + MCP_TOOLS_SCOPE, +} from "@/lib/oauth/constants"; /** * Build a minimal AuthInfo that makes tool handlers believe @@ -37,7 +41,7 @@ export function makeTestAuthInfo(userId: string): AuthInfo { return { token: "integration-test-token", clientId: "integration-test-client", - scopes: ["openid", "profile", "mcp:tools"], + scopes: [...DEFAULT_OAUTH_CLIENT_SCOPES, MCP_TOOLS_SCOPE], extra: { userId }, }; } diff --git a/tests/unit/compact-payload.test.ts b/tests/unit/compact-payload.test.ts index 23f32e37..d9008aca 100644 --- a/tests/unit/compact-payload.test.ts +++ b/tests/unit/compact-payload.test.ts @@ -39,6 +39,31 @@ describe("compactMcpPayload", () => { (result[0].todos as Record[])[0], ).not.toHaveProperty("extra"); }); + + it("compacts top-level arrays of known records", () => { + const input = [ + { + id: "c1", + jwId: "J1", + code: "CS101", + namePrimary: "Intro CS", + credit: 3, + hours: 48, + description: "removed", + }, + ]; + + const result = compactMcpPayload(input) as Record[]; + + expect(result[0]).toEqual({ + id: "c1", + jwId: "J1", + code: "CS101", + namePrimary: "Intro CS", + credit: 3, + hours: 48, + }); + }); }); describe("todos", () => { @@ -460,7 +485,7 @@ describe("compactMcpPayload", () => { expect(events[0]).toEqual({ type: "schedule", at: "2024-01-01" }); }); - it("compacts generic event payloads that look like schedules", () => { + it("does not compact generic payloads by structural inference", () => { const input = { nextClass: { type: "schedule", @@ -487,8 +512,8 @@ describe("compactMcpPayload", () => { const result = compactMcpPayload(input) as Record; const nextClass = result.nextClass as Record; const payload = nextClass.payload as Record; - expect(payload).not.toHaveProperty("scheduleGroup"); - expect(payload).not.toHaveProperty("roomType"); + expect(payload).toHaveProperty("scheduleGroup"); + expect(payload).toHaveProperty("roomType"); expect(payload.room).toEqual({ id: "r1", jwId: "RJ1", diff --git a/tests/unit/mcp-auth.test.ts b/tests/unit/mcp-auth.test.ts index e5e0b719..f2892ddc 100644 --- a/tests/unit/mcp-auth.test.ts +++ b/tests/unit/mcp-auth.test.ts @@ -1,4 +1,5 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; +import { MCP_TOOLS_SCOPE } from "@/lib/oauth/constants"; const verifyOAuthAccessTokenMock = vi.fn(); @@ -31,10 +32,7 @@ vi.mock("@/lib/mcp/urls", () => ({ new URL( "https://life.example/.well-known/oauth-protected-resource/api/mcp", ), - getOAuthTokenVerificationIssuers: () => [ - "https://life.example/api/auth", - "https://life.example", - ], + getOAuthTokenVerificationIssuers: () => ["https://life.example/api/auth"], })); describe("MCP auth", () => { @@ -43,12 +41,12 @@ describe("MCP auth", () => { verifyOAuthAccessTokenMock.mockReset(); }); - it("accepts canonical and legacy OAuth issuers for JWT access tokens", async () => { + it("verifies JWT access tokens against the canonical OAuth issuer", async () => { verifyOAuthAccessTokenMock.mockResolvedValue({ azp: "client-id", aud: "https://life.example/api/mcp", exp: 1_900_000_000, - scope: "mcp:tools", + scope: MCP_TOOLS_SCOPE, sub: "user-id", }); const { verifyAccessToken } = await import("@/lib/mcp/auth"); @@ -63,7 +61,7 @@ describe("MCP auth", () => { expect.objectContaining({ jwksUrl: "https://life.example/api/auth/jwks", verifyOptions: { - issuer: ["https://life.example/api/auth", "https://life.example"], + issuer: ["https://life.example/api/auth"], audience: [ "https://life.example/api/mcp", "https://life.example/api/auth/oauth2/userinfo", @@ -74,7 +72,7 @@ describe("MCP auth", () => { ); expect(authInfo).toMatchObject({ clientId: "client-id", - scopes: ["mcp:tools"], + scopes: [MCP_TOOLS_SCOPE], extra: { userId: "user-id" }, }); }); diff --git a/tests/unit/mcp-tool-helpers.test.ts b/tests/unit/mcp-tool-helpers.test.ts new file mode 100644 index 00000000..5356eb85 --- /dev/null +++ b/tests/unit/mcp-tool-helpers.test.ts @@ -0,0 +1,83 @@ +import { describe, expect, it, vi } from "vitest"; + +vi.mock("@/lib/db/prisma", () => ({ + getPrisma: vi.fn(), + prisma: {}, +})); + +import { jsonToolResult } from "@/lib/mcp/tools/_helpers"; + +function parseToolText(result: ReturnType) { + const text = result.content.find( + (item): item is { type: "text"; text: string } => + item.type === "text" && typeof item.text === "string", + )?.text; + + expect(text).toBeDefined(); + return JSON.parse(text ?? "{}") as Record; +} + +describe("jsonToolResult summary mode", () => { + it("keeps paginated totals while reporting returned and sampled items", () => { + const result = parseToolText( + jsonToolResult( + { + data: Array.from({ length: 12 }, (_, index) => ({ + id: index + 1, + title: `Item ${index + 1}`, + })), + pagination: { + page: 2, + pageSize: 12, + total: 53, + totalPages: 5, + }, + }, + { mode: "summary" }, + ), + ); + + expect(result.pagination).toEqual({ + page: 2, + pageSize: 12, + total: 53, + totalPages: 5, + }); + expect(result.data).toEqual({ + total: 53, + returned: 12, + remaining: 2, + truncated: true, + items: Array.from({ length: 10 }, (_, index) => ({ + id: index + 1, + title: `Item ${index + 1}`, + })), + }); + }); + + it("reports non-paginated list truncation metadata", () => { + const result = parseToolText( + jsonToolResult( + { + homeworks: Array.from({ length: 3 }, (_, index) => ({ + id: `hw-${index + 1}`, + title: `Homework ${index + 1}`, + })), + }, + { mode: "summary" }, + ), + ); + + expect(result.homeworks).toEqual({ + total: 3, + returned: 3, + remaining: 0, + truncated: false, + items: [ + { id: "hw-1", title: "Homework 1" }, + { id: "hw-2", title: "Homework 2" }, + { id: "hw-3", title: "Homework 3" }, + ], + }); + }); +}); diff --git a/tests/unit/mcp-urls.test.ts b/tests/unit/mcp-urls.test.ts index a782f93f..0dfa5224 100644 --- a/tests/unit/mcp-urls.test.ts +++ b/tests/unit/mcp-urls.test.ts @@ -4,7 +4,6 @@ import { getOAuthAuthorizationServerMetadataUrl, getOAuthIssuerUrl, getOAuthMcpAudienceUrls, - getOAuthOpenIdConfigurationCompatibilityUrl, getOAuthOpenIdConfigurationUrl, getOAuthProtectedResourceMetadataUrl, getOAuthProviderValidAudiences, @@ -29,18 +28,10 @@ describe("MCP URL helpers", () => { it("falls back to VERCEL_URL when APP_PUBLIC_ORIGIN is unset", () => { vi.stubEnv("APP_PUBLIC_ORIGIN", ""); - vi.stubEnv("BETTER_AUTH_URL", ""); vi.stubEnv("VERCEL_URL", "life-preview.vercel.app"); expect(getPublicOrigin()).toBe("https://life-preview.vercel.app"); }); - it("falls back to BETTER_AUTH_URL when APP_PUBLIC_ORIGIN is unset", () => { - vi.stubEnv("APP_PUBLIC_ORIGIN", ""); - vi.stubEnv("BETTER_AUTH_URL", "https://legacy.example.com"); - expect(getPublicOrigin()).toBe("https://legacy.example.com"); - expect(getBetterAuthBaseUrl()).toBe("https://legacy.example.com/api/auth"); - }); - it("falls back to VERCEL_PROJECT_PRODUCTION_URL for canonical origin", () => { vi.stubEnv("APP_CANONICAL_ORIGIN", ""); vi.stubEnv("VERCEL_PROJECT_PRODUCTION_URL", "life-ustc.tiankaima.dev"); @@ -51,11 +42,9 @@ describe("MCP URL helpers", () => { vi.stubEnv("APP_PUBLIC_ORIGIN", "https://life.example.com"); expect(getCanonicalOAuthIssuer()).toBe("https://life.example.com/api/auth"); expect(getOAuthRestAudienceUrls()).toEqual([ - "https://life.example.com", "https://life.example.com/api/auth", ]); expect(getOAuthProviderValidAudiences()).toEqual([ - "https://life.example.com", "https://life.example.com/api/auth", "https://life.example.com/api/mcp", ]); @@ -73,9 +62,6 @@ describe("MCP URL helpers", () => { expect(getOAuthOpenIdConfigurationUrl().toString()).toBe( "https://life.example.com/api/auth/.well-known/openid-configuration", ); - expect(getOAuthOpenIdConfigurationCompatibilityUrl().toString()).toBe( - "https://life.example.com/.well-known/openid-configuration/api/auth", - ); expect(getOAuthProtectedResourceMetadataUrl().toString()).toBe( "https://life.example.com/.well-known/oauth-protected-resource/api/mcp", ); From 335d174f21f3467930b8078733c4dc4445cc579b Mon Sep 17 00:00:00 2001 From: Tiankai Ma Date: Thu, 28 May 2026 13:31:56 +0800 Subject: [PATCH 4/9] refactor(bus): streamline planner data and admin flows --- src/app/admin/bus/bus-version-manager.tsx | 269 ++++-- .../bus/components/bus-panel-shared.ts | 2 +- src/features/bus/components/bus-panel.tsx | 131 ++- .../bus/components/bus-planner-controls.tsx | 138 ++- .../bus/components/bus-route-table.tsx | 304 +++++-- .../bus/components/bus-transit-map-layout.ts | 26 +- .../bus/components/bus-transit-map.tsx | 34 +- src/features/bus/lib/bus-catalog.ts | 114 +++ src/features/bus/lib/bus-departures.ts | 321 +++++++ src/features/bus/lib/bus-import.ts | 14 +- src/features/bus/lib/bus-route-builder.ts | 109 +++ src/features/bus/lib/bus-service.ts | 813 +----------------- src/features/bus/lib/bus-time.ts | 14 + src/features/bus/lib/bus-transit-map.ts | 127 +++ src/features/bus/lib/bus-version.ts | 94 ++ tests/e2e/src/app/admin/bus/test.ts | 11 +- tests/e2e/src/app/bus/test.ts | 21 +- tests/unit/bus-client.test.ts | 6 +- tests/unit/bus-service.test.ts | 18 + tests/unit/bus-static-source.test.ts | 2 +- 20 files changed, 1498 insertions(+), 1070 deletions(-) create mode 100644 src/features/bus/lib/bus-catalog.ts create mode 100644 src/features/bus/lib/bus-departures.ts create mode 100644 src/features/bus/lib/bus-route-builder.ts create mode 100644 src/features/bus/lib/bus-time.ts create mode 100644 src/features/bus/lib/bus-transit-map.ts create mode 100644 src/features/bus/lib/bus-version.ts create mode 100644 tests/unit/bus-service.test.ts diff --git a/src/app/admin/bus/bus-version-manager.tsx b/src/app/admin/bus/bus-version-manager.tsx index 3d2f0e46..07628e4e 100644 --- a/src/app/admin/bus/bus-version-manager.tsx +++ b/src/app/admin/bus/bus-version-manager.tsx @@ -102,9 +102,14 @@ export function BusVersionManager({ versions }: { versions: VersionRow[] }) { return (
-
+

{t("versionsTitle")}

- + )} + {!v.isEnabled && ( + + )} +
+ + + ))} + + +
+ )} | null>(null); const abortRef = useRef(null); + const saveGenerationRef = useRef(0); useEffect(() => { setSelectedDayType(resolveClientBusDayType(new Date())); @@ -91,6 +89,14 @@ export function BusPanel({ }), [data, selectedDayType, startCampusId, endCampusId, showDepartedTrips, now], ); + const startCampus = useMemo( + () => data.campuses.find((campus) => campus.id === startCampusId) ?? null, + [data.campuses, startCampusId], + ); + const endCampus = useMemo( + () => data.campuses.find((campus) => campus.id === endCampusId) ?? null, + [data.campuses, endCampusId], + ); const showPlannerEstimatedHint = useMemo(() => { const inVisibleRows = applicableRoutes.some((route) => @@ -115,6 +121,7 @@ export function BusPanel({ abortRef.current?.abort(); const controller = new AbortController(); abortRef.current = controller; + const saveGeneration = saveGenerationRef.current; setSaveState("saving"); setSaveError(null); @@ -140,7 +147,10 @@ export function BusPanel({ body = null; } + if (saveGeneration !== saveGenerationRef.current) return; + if (!response.ok) { + dirtyRef.current = true; setSaveState("error"); setSaveError( extractApiErrorMessage(body) ?? t("preferences.saveFailed"), @@ -148,9 +158,12 @@ export function BusPanel({ return; } + dirtyRef.current = false; setSaveState("saved"); } catch (error) { if ((error as Error).name === "AbortError") return; + if (saveGeneration !== saveGenerationRef.current) return; + dirtyRef.current = true; setSaveState("error"); setSaveError(t("preferences.saveFailed")); } @@ -168,7 +181,6 @@ export function BusPanel({ timerRef.current = setTimeout(() => { savePreference(startCampusId, endCampusId, showDepartedTrips); - dirtyRef.current = false; }, AUTO_SAVE_DELAY_MS); return () => { @@ -188,6 +200,7 @@ export function BusPanel({ const markDirty = useCallback(() => { if (!signedIn || !showPreferences) return; dirtyRef.current = true; + saveGenerationRef.current += 1; setSaveState("idle"); setSaveError(null); }, [showPreferences, signedIn]); @@ -205,91 +218,53 @@ export function BusPanel({ ? t("preferences.saved") : saveState === "error" ? (saveError ?? t("preferences.saveFailed")) - : showPreferences && signedIn - ? t("preferences.autosaveHint") - : t("planner.clientHint"); - - const plannerActions = ( - <> - -