Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@ Thumbs.db
sample-readme.md
sample_tests.md
sample_setup.md
docs/DASHBOARD_GUIDE.md
docs/DASHBOARD_GUIDE.md

scripts/
2 changes: 2 additions & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class Settings(BaseSettings):

dashboard_cache_ttl_seconds: int = 60
recycle_bin_retention_days: int = 30
max_csv_import_bytes: int = 5 * 1024 * 1024
jwt_clock_skew_seconds: int = 10

@computed_field
@property
Expand Down
4 changes: 4 additions & 0 deletions app/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from contextlib import asynccontextmanager
import logging

from fastapi import FastAPI, WebSocket

Expand All @@ -8,10 +9,13 @@
from app.services.presence_service import handle_presence_socket

settings = get_settings()
logger = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(_: FastAPI):
if settings.secret_key == "replace-with-strong-secret":
logger.warning("Using default SECRET_KEY placeholder. Set SECRET_KEY for non-local environments.")
yield
await close_connections()

Expand Down
1 change: 1 addition & 0 deletions app/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ async def decode_token(token: str) -> TokenPayload:
token,
settings.secret_key,
[settings.algorithm],
{"leeway": settings.jwt_clock_skew_seconds},
)
return TokenPayload(**payload)
except JWTError as exc:
Expand Down
32 changes: 29 additions & 3 deletions app/routes/financial_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status
from redis.asyncio import Redis
from sqlalchemy import delete, func, or_, select
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import get_settings
Expand Down Expand Up @@ -53,6 +54,17 @@ async def _purge_expired_deleted_records(db: AsyncSession) -> None:
await db.commit()


async def _safe_commit(db: AsyncSession, *, integrity_detail: str, generic_detail: str) -> None:
try:
await db.commit()
except IntegrityError as exc:
await db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=integrity_detail) from exc
except SQLAlchemyError as exc:
await db.rollback()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=generic_detail) from exc


@router.post("", response_model=FinancialRecordOut, status_code=status.HTTP_201_CREATED)
async def create_financial_record(
payload: FinancialRecordCreate,
Expand All @@ -78,7 +90,11 @@ async def create_financial_record(
user_id=target_user_id,
)
db.add(row)
await db.commit()
await _safe_commit(
db,
integrity_detail="Unable to create record with provided data",
generic_detail="Failed to create financial record",
)
await db.refresh(row)
await invalidate_dashboard_cache(redis)
return FinancialRecordOut.model_validate(row)
Expand Down Expand Up @@ -208,6 +224,8 @@ async def import_financial_records_csv(
raw_body = await request.body()
if not raw_body:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="CSV body is empty")
if len(raw_body) > get_settings().max_csv_import_bytes:
raise HTTPException(status_code=status.HTTP_413_CONTENT_TOO_LARGE, detail="CSV body too large")

try:
csv_text = raw_body.decode("utf-8")
Expand Down Expand Up @@ -254,7 +272,11 @@ async def import_financial_records_csv(
for _, payload, _ in valid_rows
]
db.add_all(created_rows)
await db.commit()
await _safe_commit(
db,
integrity_detail="CSV contains rows that violate data constraints",
generic_detail="Failed to import CSV records",
)
imported_count = len(created_rows)
await invalidate_dashboard_cache(redis)

Expand Down Expand Up @@ -326,7 +348,11 @@ async def update_financial_record(
else:
setattr(row, key, value)

await db.commit()
await _safe_commit(
db,
integrity_detail="Unable to update record with provided data",
generic_detail="Failed to update financial record",
)
await db.refresh(row)
await invalidate_dashboard_cache(redis)
return FinancialRecordOut.model_validate(row)
Expand Down
1 change: 1 addition & 0 deletions app/routes/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ async def update_user_by_admin(
payload: UserUpdate,
db: AsyncSession = Depends(get_db),
_: User = Depends(require_roles(UserRole.ADMIN)),
__: None = Depends(sensitive_route_limiter),
) -> UserOut:
if payload.role is None and payload.is_active is None:
raise HTTPException(
Expand Down
46 changes: 32 additions & 14 deletions app/services/auth_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from datetime import datetime, timezone
import logging

from sqlalchemy.dialects.postgresql import insert
from redis.asyncio import Redis
from redis.exceptions import RedisError
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

Expand All @@ -15,6 +18,8 @@
)
from app.schemas import Token

logger = logging.getLogger(__name__)


async def authenticate_user(db: AsyncSession, email: str, password: str) -> User | None:
user = await db.scalar(select(User).where(User.email == email))
Expand Down Expand Up @@ -50,27 +55,27 @@ async def issue_token_pair(db: AsyncSession, user_id: int) -> Token:

async def _blacklist_token(
db: AsyncSession,
redis: Redis,
redis: Redis | None,
jti: str,
token_type: str,
expires_at: datetime,
persist_db: bool = True,
) -> None:
existing = await db.scalar(select(TokenBlacklist).where(TokenBlacklist.jti == jti))
if existing is None:
db.add(
TokenBlacklist(
jti=jti,
token_type=token_type,
expires_at=expires_at,
)
if persist_db:
stmt = insert(TokenBlacklist).values(
jti=jti,
token_type=token_type,
expires_at=expires_at,
)
stmt = stmt.on_conflict_do_nothing(index_elements=[TokenBlacklist.jti])
await db.execute(stmt)

ttl_seconds = int((expires_at - datetime.now(timezone.utc)).total_seconds())
if ttl_seconds > 0:
if ttl_seconds > 0 and redis is not None:
try:
await redis.set(f"blacklist:{jti}", "1", ex=ttl_seconds)
except RedisError:
pass
except RedisError as exc:
logger.warning("Failed to write token blacklist to Redis for jti=%s: %s", jti, exc)


async def refresh_access_pair(db: AsyncSession, redis: Redis, refresh_token: str) -> Token:
Expand Down Expand Up @@ -106,14 +111,27 @@ async def refresh_access_pair(db: AsyncSession, redis: Redis, refresh_token: str

await _blacklist_token(
db=db,
redis=redis,
redis=None,
jti=payload.jti,
token_type="refresh",
expires_at=exp_to_datetime(payload.exp),
)

await _store_refresh_token(db, user.id, new_refresh_token)
await db.commit()
try:
await db.commit()
except SQLAlchemyError as exc:
await db.rollback()
raise ValueError("Unable to refresh token at the moment") from exc

await _blacklist_token(
db=db,
redis=redis,
jti=payload.jti,
token_type="refresh",
expires_at=exp_to_datetime(payload.exp),
persist_db=False,
)

return Token(access_token=new_access_token, refresh_token=new_refresh_token)

Expand Down
27 changes: 17 additions & 10 deletions app/services/dashboard_service.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
import logging
from datetime import date
from decimal import Decimal

from redis.asyncio import Redis
from redis.exceptions import RedisError
from sqlalchemy import case, func, select
from sqlalchemy.ext.asyncio import AsyncSession

Expand All @@ -11,6 +13,7 @@
from app.schemas import CategoryTotal, MonthlyTrendPoint, RecentActivityItem, SummaryTotals

settings = get_settings()
logger = logging.getLogger(__name__)


def _decimal_to_str(value: Decimal) -> str:
Expand All @@ -37,7 +40,8 @@ async def invalidate_dashboard_cache(redis: Redis) -> None:
keys = await redis.keys("dashboard:*")
if keys:
await redis.delete(*keys)
except Exception:
except RedisError as exc:
logger.warning("Failed to invalidate dashboard cache: %s", exc)
return


Expand All @@ -50,7 +54,8 @@ async def get_summary_totals(
key = _cache_key("totals", start_date, end_date)
try:
cached = await redis.get(key)
except Exception:
except RedisError as exc:
logger.warning("Failed to read dashboard cache key %s: %s", key, exc)
cached = None
if cached:
payload = json.loads(cached)
Expand Down Expand Up @@ -94,8 +99,8 @@ async def get_summary_totals(
),
ex=settings.dashboard_cache_ttl_seconds,
)
except Exception:
pass
except RedisError as exc:
logger.warning("Failed to write dashboard cache key %s: %s", key, exc)
return result


Expand All @@ -108,7 +113,8 @@ async def get_category_totals(
key = _cache_key("categories", start_date, end_date)
try:
cached = await redis.get(key)
except Exception:
except RedisError as exc:
logger.warning("Failed to read dashboard cache key %s: %s", key, exc)
cached = None
if cached:
payload = json.loads(cached)
Expand All @@ -130,8 +136,8 @@ async def get_category_totals(
),
ex=settings.dashboard_cache_ttl_seconds,
)
except Exception:
pass
except RedisError as exc:
logger.warning("Failed to write dashboard cache key %s: %s", key, exc)
return result


Expand Down Expand Up @@ -169,7 +175,8 @@ async def get_monthly_trends(
key = _cache_key("monthly_trends", start_date, end_date)
try:
cached = await redis.get(key)
except Exception:
except RedisError as exc:
logger.warning("Failed to read dashboard cache key %s: %s", key, exc)
cached = None
if cached:
payload = json.loads(cached)
Expand Down Expand Up @@ -225,6 +232,6 @@ async def get_monthly_trends(
),
ex=settings.dashboard_cache_ttl_seconds,
)
except Exception:
pass
except RedisError as exc:
logger.warning("Failed to write dashboard cache key %s: %s", key, exc)
return result
18 changes: 13 additions & 5 deletions app/services/presence_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import logging

from fastapi import WebSocket
from fastapi import WebSocket, WebSocketDisconnect
from jose import JWTError
from redis.asyncio import Redis
from sqlalchemy.ext.asyncio import AsyncSession
Expand All @@ -12,6 +13,7 @@
from app.oauth2 import decode_token

settings = get_settings()
logger = logging.getLogger(__name__)


def _presence_key(user_id: int) -> str:
Expand Down Expand Up @@ -78,7 +80,8 @@ async def handle_presence_socket(websocket: WebSocket, redis: Redis) -> None:

try:
user_id = await _authenticate_socket_user(websocket, redis)
except ValueError:
except ValueError as exc:
logger.debug("Presence websocket authentication failed: %s", exc)
return

await set_online(redis, user_id)
Expand All @@ -96,7 +99,12 @@ async def handle_presence_socket(websocket: WebSocket, redis: Redis) -> None:
except asyncio.TimeoutError:
await websocket.send_json({"type": "ping"})
await set_online(redis, user_id)
except Exception:
pass
except WebSocketDisconnect:
logger.debug("Presence websocket disconnected for user_id=%s", user_id)
except Exception as exc:
logger.warning("Presence websocket loop error for user_id=%s: %s", user_id, exc)
finally:
await set_offline(redis, user_id)
try:
await set_offline(redis, user_id)
except Exception as exc:
logger.warning("Failed to clear presence state for user_id=%s: %s", user_id, exc)
20 changes: 17 additions & 3 deletions app/services/rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
from app.db import get_redis
from app.oauth2 import decode_token

INCREMENT_WITH_TTL_SCRIPT = """
local current = redis.call('INCR', KEYS[1])
if current == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[1])
end
return current
"""


def fixed_window_rate_limiter(
key_prefix: str,
Expand Down Expand Up @@ -35,12 +43,18 @@ async def dependency(
key = f"{key_prefix}:{identity}:{path_key}"

try:
current = await redis.incr(key)
if current == 1:
await redis.expire(key, window_seconds)
try:
current = int(await redis.eval(INCREMENT_WITH_TTL_SCRIPT, 1, key, window_seconds))
except RedisError:
# Some test or managed Redis setups can disable scripting.
current = await redis.incr(key)
if current == 1:
await redis.expire(key, window_seconds)

if current > max_requests:
retry_after = await redis.ttl(key)
if retry_after <= 0:
retry_after = window_seconds
raise HTTPException(
status_code=429,
detail="Rate limit exceeded",
Expand Down
Loading
Loading