diff --git a/.gitignore b/.gitignore index 5eafdc7..e223391 100644 --- a/.gitignore +++ b/.gitignore @@ -32,4 +32,6 @@ Thumbs.db sample-readme.md sample_tests.md sample_setup.md -docs/DASHBOARD_GUIDE.md \ No newline at end of file +docs/DASHBOARD_GUIDE.md + +scripts/ \ No newline at end of file diff --git a/app/config.py b/app/config.py index 507d92f..ff4316c 100644 --- a/app/config.py +++ b/app/config.py @@ -39,6 +39,8 @@ class Settings(BaseSettings): dashboard_cache_ttl_seconds: int = 60 recycle_bin_retention_days: int = 30 + max_csv_import_bytes: int = 5 * 1024 * 1024 + jwt_clock_skew_seconds: int = 10 @computed_field @property diff --git a/app/main.py b/app/main.py index eae1b32..5536317 100644 --- a/app/main.py +++ b/app/main.py @@ -1,4 +1,5 @@ from contextlib import asynccontextmanager +import logging from fastapi import FastAPI, WebSocket @@ -8,10 +9,13 @@ from app.services.presence_service import handle_presence_socket settings = get_settings() +logger = logging.getLogger(__name__) @asynccontextmanager async def lifespan(_: FastAPI): + if settings.secret_key == "replace-with-strong-secret": + logger.warning("Using default SECRET_KEY placeholder. Set SECRET_KEY for non-local environments.") yield await close_connections() diff --git a/app/oauth2.py b/app/oauth2.py index 9979660..7628578 100644 --- a/app/oauth2.py +++ b/app/oauth2.py @@ -62,6 +62,7 @@ async def decode_token(token: str) -> TokenPayload: token, settings.secret_key, [settings.algorithm], + {"leeway": settings.jwt_clock_skew_seconds}, ) return TokenPayload(**payload) except JWTError as exc: diff --git a/app/routes/financial_records.py b/app/routes/financial_records.py index 9aa24dd..b26a48a 100644 --- a/app/routes/financial_records.py +++ b/app/routes/financial_records.py @@ -3,6 +3,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status from redis.asyncio import Redis from sqlalchemy import delete, func, or_, select +from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession from app.config import get_settings @@ -53,6 +54,17 @@ async def _purge_expired_deleted_records(db: AsyncSession) -> None: await db.commit() +async def _safe_commit(db: AsyncSession, *, integrity_detail: str, generic_detail: str) -> None: + try: + await db.commit() + except IntegrityError as exc: + await db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=integrity_detail) from exc + except SQLAlchemyError as exc: + await db.rollback() + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=generic_detail) from exc + + @router.post("", response_model=FinancialRecordOut, status_code=status.HTTP_201_CREATED) async def create_financial_record( payload: FinancialRecordCreate, @@ -78,7 +90,11 @@ async def create_financial_record( user_id=target_user_id, ) db.add(row) - await db.commit() + await _safe_commit( + db, + integrity_detail="Unable to create record with provided data", + generic_detail="Failed to create financial record", + ) await db.refresh(row) await invalidate_dashboard_cache(redis) return FinancialRecordOut.model_validate(row) @@ -208,6 +224,8 @@ async def import_financial_records_csv( raw_body = await request.body() if not raw_body: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="CSV body is empty") + if len(raw_body) > get_settings().max_csv_import_bytes: + raise HTTPException(status_code=status.HTTP_413_CONTENT_TOO_LARGE, detail="CSV body too large") try: csv_text = raw_body.decode("utf-8") @@ -254,7 +272,11 @@ async def import_financial_records_csv( for _, payload, _ in valid_rows ] db.add_all(created_rows) - await db.commit() + await _safe_commit( + db, + integrity_detail="CSV contains rows that violate data constraints", + generic_detail="Failed to import CSV records", + ) imported_count = len(created_rows) await invalidate_dashboard_cache(redis) @@ -326,7 +348,11 @@ async def update_financial_record( else: setattr(row, key, value) - await db.commit() + await _safe_commit( + db, + integrity_detail="Unable to update record with provided data", + generic_detail="Failed to update financial record", + ) await db.refresh(row) await invalidate_dashboard_cache(redis) return FinancialRecordOut.model_validate(row) diff --git a/app/routes/users.py b/app/routes/users.py index e65aa38..65a4538 100644 --- a/app/routes/users.py +++ b/app/routes/users.py @@ -228,6 +228,7 @@ async def update_user_by_admin( payload: UserUpdate, db: AsyncSession = Depends(get_db), _: User = Depends(require_roles(UserRole.ADMIN)), + __: None = Depends(sensitive_route_limiter), ) -> UserOut: if payload.role is None and payload.is_active is None: raise HTTPException( diff --git a/app/services/auth_service.py b/app/services/auth_service.py index 9085649..b8fa26c 100644 --- a/app/services/auth_service.py +++ b/app/services/auth_service.py @@ -1,7 +1,10 @@ from datetime import datetime, timezone +import logging +from sqlalchemy.dialects.postgresql import insert from redis.asyncio import Redis from redis.exceptions import RedisError +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -15,6 +18,8 @@ ) from app.schemas import Token +logger = logging.getLogger(__name__) + async def authenticate_user(db: AsyncSession, email: str, password: str) -> User | None: user = await db.scalar(select(User).where(User.email == email)) @@ -50,27 +55,27 @@ async def issue_token_pair(db: AsyncSession, user_id: int) -> Token: async def _blacklist_token( db: AsyncSession, - redis: Redis, + redis: Redis | None, jti: str, token_type: str, expires_at: datetime, + persist_db: bool = True, ) -> None: - existing = await db.scalar(select(TokenBlacklist).where(TokenBlacklist.jti == jti)) - if existing is None: - db.add( - TokenBlacklist( - jti=jti, - token_type=token_type, - expires_at=expires_at, - ) + if persist_db: + stmt = insert(TokenBlacklist).values( + jti=jti, + token_type=token_type, + expires_at=expires_at, ) + stmt = stmt.on_conflict_do_nothing(index_elements=[TokenBlacklist.jti]) + await db.execute(stmt) ttl_seconds = int((expires_at - datetime.now(timezone.utc)).total_seconds()) - if ttl_seconds > 0: + if ttl_seconds > 0 and redis is not None: try: await redis.set(f"blacklist:{jti}", "1", ex=ttl_seconds) - except RedisError: - pass + except RedisError as exc: + logger.warning("Failed to write token blacklist to Redis for jti=%s: %s", jti, exc) async def refresh_access_pair(db: AsyncSession, redis: Redis, refresh_token: str) -> Token: @@ -106,14 +111,27 @@ async def refresh_access_pair(db: AsyncSession, redis: Redis, refresh_token: str await _blacklist_token( db=db, - redis=redis, + redis=None, jti=payload.jti, token_type="refresh", expires_at=exp_to_datetime(payload.exp), ) await _store_refresh_token(db, user.id, new_refresh_token) - await db.commit() + try: + await db.commit() + except SQLAlchemyError as exc: + await db.rollback() + raise ValueError("Unable to refresh token at the moment") from exc + + await _blacklist_token( + db=db, + redis=redis, + jti=payload.jti, + token_type="refresh", + expires_at=exp_to_datetime(payload.exp), + persist_db=False, + ) return Token(access_token=new_access_token, refresh_token=new_refresh_token) diff --git a/app/services/dashboard_service.py b/app/services/dashboard_service.py index 31b331e..219d928 100644 --- a/app/services/dashboard_service.py +++ b/app/services/dashboard_service.py @@ -1,8 +1,10 @@ import json +import logging from datetime import date from decimal import Decimal from redis.asyncio import Redis +from redis.exceptions import RedisError from sqlalchemy import case, func, select from sqlalchemy.ext.asyncio import AsyncSession @@ -11,6 +13,7 @@ from app.schemas import CategoryTotal, MonthlyTrendPoint, RecentActivityItem, SummaryTotals settings = get_settings() +logger = logging.getLogger(__name__) def _decimal_to_str(value: Decimal) -> str: @@ -37,7 +40,8 @@ async def invalidate_dashboard_cache(redis: Redis) -> None: keys = await redis.keys("dashboard:*") if keys: await redis.delete(*keys) - except Exception: + except RedisError as exc: + logger.warning("Failed to invalidate dashboard cache: %s", exc) return @@ -50,7 +54,8 @@ async def get_summary_totals( key = _cache_key("totals", start_date, end_date) try: cached = await redis.get(key) - except Exception: + except RedisError as exc: + logger.warning("Failed to read dashboard cache key %s: %s", key, exc) cached = None if cached: payload = json.loads(cached) @@ -94,8 +99,8 @@ async def get_summary_totals( ), ex=settings.dashboard_cache_ttl_seconds, ) - except Exception: - pass + except RedisError as exc: + logger.warning("Failed to write dashboard cache key %s: %s", key, exc) return result @@ -108,7 +113,8 @@ async def get_category_totals( key = _cache_key("categories", start_date, end_date) try: cached = await redis.get(key) - except Exception: + except RedisError as exc: + logger.warning("Failed to read dashboard cache key %s: %s", key, exc) cached = None if cached: payload = json.loads(cached) @@ -130,8 +136,8 @@ async def get_category_totals( ), ex=settings.dashboard_cache_ttl_seconds, ) - except Exception: - pass + except RedisError as exc: + logger.warning("Failed to write dashboard cache key %s: %s", key, exc) return result @@ -169,7 +175,8 @@ async def get_monthly_trends( key = _cache_key("monthly_trends", start_date, end_date) try: cached = await redis.get(key) - except Exception: + except RedisError as exc: + logger.warning("Failed to read dashboard cache key %s: %s", key, exc) cached = None if cached: payload = json.loads(cached) @@ -225,6 +232,6 @@ async def get_monthly_trends( ), ex=settings.dashboard_cache_ttl_seconds, ) - except Exception: - pass + except RedisError as exc: + logger.warning("Failed to write dashboard cache key %s: %s", key, exc) return result diff --git a/app/services/presence_service.py b/app/services/presence_service.py index 537c6d6..8a73dc0 100644 --- a/app/services/presence_service.py +++ b/app/services/presence_service.py @@ -1,6 +1,7 @@ import asyncio +import logging -from fastapi import WebSocket +from fastapi import WebSocket, WebSocketDisconnect from jose import JWTError from redis.asyncio import Redis from sqlalchemy.ext.asyncio import AsyncSession @@ -12,6 +13,7 @@ from app.oauth2 import decode_token settings = get_settings() +logger = logging.getLogger(__name__) def _presence_key(user_id: int) -> str: @@ -78,7 +80,8 @@ async def handle_presence_socket(websocket: WebSocket, redis: Redis) -> None: try: user_id = await _authenticate_socket_user(websocket, redis) - except ValueError: + except ValueError as exc: + logger.debug("Presence websocket authentication failed: %s", exc) return await set_online(redis, user_id) @@ -96,7 +99,12 @@ async def handle_presence_socket(websocket: WebSocket, redis: Redis) -> None: except asyncio.TimeoutError: await websocket.send_json({"type": "ping"}) await set_online(redis, user_id) - except Exception: - pass + except WebSocketDisconnect: + logger.debug("Presence websocket disconnected for user_id=%s", user_id) + except Exception as exc: + logger.warning("Presence websocket loop error for user_id=%s: %s", user_id, exc) finally: - await set_offline(redis, user_id) + try: + await set_offline(redis, user_id) + except Exception as exc: + logger.warning("Failed to clear presence state for user_id=%s: %s", user_id, exc) diff --git a/app/services/rate_limiter.py b/app/services/rate_limiter.py index bc568a6..19f7438 100644 --- a/app/services/rate_limiter.py +++ b/app/services/rate_limiter.py @@ -7,6 +7,14 @@ from app.db import get_redis from app.oauth2 import decode_token +INCREMENT_WITH_TTL_SCRIPT = """ +local current = redis.call('INCR', KEYS[1]) +if current == 1 then + redis.call('EXPIRE', KEYS[1], ARGV[1]) +end +return current +""" + def fixed_window_rate_limiter( key_prefix: str, @@ -35,12 +43,18 @@ async def dependency( key = f"{key_prefix}:{identity}:{path_key}" try: - current = await redis.incr(key) - if current == 1: - await redis.expire(key, window_seconds) + try: + current = int(await redis.eval(INCREMENT_WITH_TTL_SCRIPT, 1, key, window_seconds)) + except RedisError: + # Some test or managed Redis setups can disable scripting. + current = await redis.incr(key) + if current == 1: + await redis.expire(key, window_seconds) if current > max_requests: retry_after = await redis.ttl(key) + if retry_after <= 0: + retry_after = window_seconds raise HTTPException( status_code=429, detail="Rate limit exceeded", diff --git a/tests/test_auth_routes.py b/tests/test_auth_routes.py index 324114b..c007dcd 100644 --- a/tests/test_auth_routes.py +++ b/tests/test_auth_routes.py @@ -1,4 +1,5 @@ from httpx import AsyncClient +import asyncio from app.config import get_settings @@ -176,3 +177,17 @@ async def test_bootstrap_admin_works_once_with_valid_key(client: AsyncClient, mo monkeypatch.delenv("ADMIN_BOOTSTRAP_KEY", raising=False) get_settings.cache_clear() + + +async def test_refresh_concurrent_requests_do_not_return_500(client: AsyncClient, user_factory): + user = await user_factory(role="admin") + old_refresh = user["tokens"]["refresh_token"] + + async def do_refresh(): + return await client.post("/api/v1/users/refresh", json={"refresh_token": old_refresh}) + + first, second = await asyncio.gather(do_refresh(), do_refresh()) + statuses = {first.status_code, second.status_code} + + assert 500 not in statuses + assert statuses.issubset({200, 401}) diff --git a/tests/test_dashboard_routes.py b/tests/test_dashboard_routes.py index 38dc830..fb40335 100644 --- a/tests/test_dashboard_routes.py +++ b/tests/test_dashboard_routes.py @@ -1,4 +1,5 @@ from decimal import Decimal +from redis.exceptions import RedisError from httpx import AsyncClient @@ -153,3 +154,22 @@ async def test_dashboard_invalid_date_range_returns_400(client: AsyncClient, use headers=_auth(token), ) assert resp.status_code == 400 + + +async def test_dashboard_summary_falls_back_when_redis_unavailable(client: AsyncClient, user_factory, monkeypatch): + import app.main as app_main + + viewer = await user_factory(role="viewer") + token = viewer["tokens"]["access_token"] + + async def failing_get(*_args, **_kwargs): + raise RedisError("simulated redis read failure") + + async def failing_set(*_args, **_kwargs): + raise RedisError("simulated redis write failure") + + monkeypatch.setattr(app_main.redis_client, "get", failing_get) + monkeypatch.setattr(app_main.redis_client, "set", failing_set) + + resp = await client.get("/api/v1/dashboard/summary", headers=_auth(token)) + assert resp.status_code == 200 diff --git a/tests/test_financial_records_routes.py b/tests/test_financial_records_routes.py index 2ffe3b7..e58721f 100644 --- a/tests/test_financial_records_routes.py +++ b/tests/test_financial_records_routes.py @@ -536,3 +536,21 @@ async def test_import_csv_requires_admin(client: AsyncClient, user_factory): content="amount,record_type,category,entry_date\n100,income,salary,2026-04-08\n", ) assert imported.status_code == 403 + + +async def test_import_csv_rejects_oversized_body(client: AsyncClient, user_factory, monkeypatch): + monkeypatch.setenv("MAX_CSV_IMPORT_BYTES", "16") + get_settings.cache_clear() + + admin = await user_factory(role="admin") + token = admin["tokens"]["access_token"] + + imported = await client.post( + "/api/v1/financial-records/import", + headers={**_auth(token), "Content-Type": "text/csv"}, + content="amount,record_type,category,entry_date\n100,income,salary,2026-04-08\n", + ) + assert imported.status_code == 413 + + monkeypatch.delenv("MAX_CSV_IMPORT_BYTES", raising=False) + get_settings.cache_clear() diff --git a/tests/test_rbac_users_routes.py b/tests/test_rbac_users_routes.py index b8bc944..0799c45 100644 --- a/tests/test_rbac_users_routes.py +++ b/tests/test_rbac_users_routes.py @@ -122,3 +122,19 @@ async def test_dashboard_summary_allows_viewer(client: AsyncClient, user_factory resp = await client.get("/api/v1/dashboard/summary", headers={"Authorization": f"Bearer {token}"}) assert resp.status_code == 200 + + +async def test_admin_update_user_route_is_rate_limited(client: AsyncClient, user_factory): + admin = await user_factory(role="admin") + viewer = await user_factory(role="viewer") + + statuses: list[int] = [] + for _ in range(25): + resp = await client.patch( + f"/api/v1/users/admin/users/{viewer['user']['id']}", + headers={"Authorization": f"Bearer {admin['tokens']['access_token']}"}, + json={"is_active": True}, + ) + statuses.append(resp.status_code) + + assert 429 in statuses