Add comprehensive backend features and mobile UI improvements
Backend: - Add 2FA authentication with TOTP support - Add API keys management system - Add audit logging for security events - Add file upload/management system - Add notifications system with preferences - Add session management - Add webhooks integration - Add analytics endpoints - Add export functionality - Add password policy enforcement - Add new database migrations for core tables Frontend: - Add module position system (top/bottom sidebar sections) - Add search and notifications module configuration tabs - Add mobile logo replacing hamburger menu - Center page title absolutely when no tabs present - Align sidebar footer toggles with navigation items - Add lighter icon color in dark theme for mobile - Add API keys management page - Add notifications page with context - Add admin analytics and audit logs pages 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -2,5 +2,11 @@
|
||||
|
||||
from app.crud.user import user
|
||||
from app.crud import settings
|
||||
from app.crud.audit_log import audit_log
|
||||
from app.crud.api_key import api_key
|
||||
from app.crud.notification import notification
|
||||
from app.crud.session import session
|
||||
from app.crud.webhook import webhook, webhook_delivery, webhook_service
|
||||
from app.crud.file import file_storage
|
||||
|
||||
__all__ = ["user", "settings"]
|
||||
__all__ = ["user", "settings", "audit_log", "api_key", "notification", "session", "webhook", "webhook_delivery", "webhook_service", "file_storage"]
|
||||
|
||||
184
backend/app/crud/api_key.py
Normal file
184
backend/app/crud/api_key.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""CRUD operations for API Key model."""
|
||||
|
||||
import json
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.api_key import APIKey, generate_api_key, generate_key_prefix
|
||||
from app.schemas.api_key import APIKeyCreate, APIKeyUpdate
|
||||
|
||||
|
||||
def hash_api_key(key: str) -> str:
|
||||
"""Hash an API key for secure storage."""
|
||||
return hashlib.sha256(key.encode()).hexdigest()
|
||||
|
||||
|
||||
class CRUDAPIKey:
|
||||
"""CRUD operations for API Key model."""
|
||||
|
||||
def create(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
obj_in: APIKeyCreate,
|
||||
user_id: str
|
||||
) -> Tuple[APIKey, str]:
|
||||
"""
|
||||
Create a new API key.
|
||||
Returns both the database object and the plain key (shown only once).
|
||||
"""
|
||||
# Generate the actual key
|
||||
plain_key = generate_api_key()
|
||||
key_hash = hash_api_key(plain_key)
|
||||
key_prefix = generate_key_prefix(plain_key)
|
||||
|
||||
# Serialize scopes to JSON
|
||||
scopes_json = json.dumps(obj_in.scopes) if obj_in.scopes else None
|
||||
|
||||
db_obj = APIKey(
|
||||
user_id=user_id,
|
||||
name=obj_in.name,
|
||||
key_hash=key_hash,
|
||||
key_prefix=key_prefix,
|
||||
scopes=scopes_json,
|
||||
expires_at=obj_in.expires_at,
|
||||
is_active=True,
|
||||
usage_count="0"
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
|
||||
return db_obj, plain_key
|
||||
|
||||
def get(self, db: Session, id: str) -> Optional[APIKey]:
|
||||
"""Get an API key by ID."""
|
||||
return db.query(APIKey).filter(APIKey.id == id).first()
|
||||
|
||||
def get_by_key(self, db: Session, plain_key: str) -> Optional[APIKey]:
|
||||
"""Get an API key by the plain key (for authentication)."""
|
||||
key_hash = hash_api_key(plain_key)
|
||||
return db.query(APIKey).filter(APIKey.key_hash == key_hash).first()
|
||||
|
||||
def get_multi_by_user(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[APIKey]:
|
||||
"""Get all API keys for a user."""
|
||||
return db.query(APIKey)\
|
||||
.filter(APIKey.user_id == user_id)\
|
||||
.order_by(APIKey.created_at.desc())\
|
||||
.offset(skip)\
|
||||
.limit(limit)\
|
||||
.all()
|
||||
|
||||
def count_by_user(self, db: Session, user_id: str) -> int:
|
||||
"""Count API keys for a user."""
|
||||
return db.query(APIKey).filter(APIKey.user_id == user_id).count()
|
||||
|
||||
def update(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: APIKey,
|
||||
obj_in: APIKeyUpdate
|
||||
) -> APIKey:
|
||||
"""Update an API key."""
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
# Handle scopes serialization
|
||||
if "scopes" in update_data:
|
||||
update_data["scopes"] = json.dumps(update_data["scopes"]) if update_data["scopes"] else None
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(db_obj, field, value)
|
||||
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def delete(self, db: Session, *, id: str) -> bool:
|
||||
"""Delete an API key."""
|
||||
obj = db.query(APIKey).filter(APIKey.id == id).first()
|
||||
if obj:
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete_by_user(self, db: Session, *, user_id: str) -> int:
|
||||
"""Delete all API keys for a user."""
|
||||
count = db.query(APIKey).filter(APIKey.user_id == user_id).delete()
|
||||
db.commit()
|
||||
return count
|
||||
|
||||
def revoke(self, db: Session, *, id: str) -> Optional[APIKey]:
|
||||
"""Revoke (deactivate) an API key."""
|
||||
obj = db.query(APIKey).filter(APIKey.id == id).first()
|
||||
if obj:
|
||||
obj.is_active = False
|
||||
db.add(obj)
|
||||
db.commit()
|
||||
db.refresh(obj)
|
||||
return obj
|
||||
|
||||
def record_usage(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: APIKey,
|
||||
ip_address: Optional[str] = None
|
||||
) -> APIKey:
|
||||
"""Record API key usage."""
|
||||
db_obj.last_used_at = datetime.utcnow()
|
||||
db_obj.last_used_ip = ip_address
|
||||
db_obj.usage_count = str(int(db_obj.usage_count or "0") + 1)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def authenticate(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
plain_key: str,
|
||||
ip_address: Optional[str] = None
|
||||
) -> Optional[APIKey]:
|
||||
"""
|
||||
Authenticate with an API key.
|
||||
Returns the key if valid, None otherwise.
|
||||
Also records usage on successful auth.
|
||||
"""
|
||||
api_key = self.get_by_key(db, plain_key)
|
||||
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
if not api_key.is_valid:
|
||||
return None
|
||||
|
||||
# Record usage
|
||||
self.record_usage(db, db_obj=api_key, ip_address=ip_address)
|
||||
|
||||
return api_key
|
||||
|
||||
def get_scopes(self, api_key: APIKey) -> List[str]:
|
||||
"""Get scopes for an API key."""
|
||||
if api_key.scopes:
|
||||
try:
|
||||
return json.loads(api_key.scopes)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
return []
|
||||
|
||||
|
||||
# Create instance
|
||||
api_key = CRUDAPIKey()
|
||||
228
backend/app/crud/audit_log.py
Normal file
228
backend/app/crud/audit_log.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""CRUD operations for Audit Log model."""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, desc
|
||||
|
||||
from app.models.audit_log import AuditLog
|
||||
from app.schemas.audit_log import AuditLogCreate, AuditLogFilter
|
||||
|
||||
|
||||
class CRUDAuditLog:
|
||||
"""CRUD operations for Audit Log model."""
|
||||
|
||||
def create(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
obj_in: AuditLogCreate
|
||||
) -> AuditLog:
|
||||
"""Create a new audit log entry."""
|
||||
db_obj = AuditLog(
|
||||
user_id=obj_in.user_id,
|
||||
username=obj_in.username,
|
||||
action=obj_in.action,
|
||||
resource_type=obj_in.resource_type,
|
||||
resource_id=obj_in.resource_id,
|
||||
details=obj_in.details,
|
||||
ip_address=obj_in.ip_address,
|
||||
user_agent=obj_in.user_agent,
|
||||
status=obj_in.status
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def log_action(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
action: str,
|
||||
resource_type: Optional[str] = None,
|
||||
resource_id: Optional[str] = None,
|
||||
details: Optional[dict] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
status: str = "success"
|
||||
) -> AuditLog:
|
||||
"""Convenience method to log an action."""
|
||||
details_str = json.dumps(details) if details else None
|
||||
obj_in = AuditLogCreate(
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
details=details_str,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
status=status
|
||||
)
|
||||
return self.create(db, obj_in=obj_in)
|
||||
|
||||
def get(self, db: Session, id: str) -> Optional[AuditLog]:
|
||||
"""Get a single audit log entry by ID."""
|
||||
return db.query(AuditLog).filter(AuditLog.id == id).first()
|
||||
|
||||
def get_multi(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
filters: Optional[AuditLogFilter] = None
|
||||
) -> tuple[List[AuditLog], int]:
|
||||
"""Get multiple audit log entries with optional filtering."""
|
||||
query = db.query(AuditLog)
|
||||
|
||||
if filters:
|
||||
if filters.user_id:
|
||||
query = query.filter(AuditLog.user_id == filters.user_id)
|
||||
if filters.username:
|
||||
query = query.filter(AuditLog.username.ilike(f"%{filters.username}%"))
|
||||
if filters.action:
|
||||
query = query.filter(AuditLog.action == filters.action)
|
||||
if filters.resource_type:
|
||||
query = query.filter(AuditLog.resource_type == filters.resource_type)
|
||||
if filters.resource_id:
|
||||
query = query.filter(AuditLog.resource_id == filters.resource_id)
|
||||
if filters.status:
|
||||
query = query.filter(AuditLog.status == filters.status)
|
||||
if filters.start_date:
|
||||
query = query.filter(AuditLog.created_at >= filters.start_date)
|
||||
if filters.end_date:
|
||||
query = query.filter(AuditLog.created_at <= filters.end_date)
|
||||
|
||||
total = query.count()
|
||||
items = query.order_by(desc(AuditLog.created_at)).offset(skip).limit(limit).all()
|
||||
|
||||
return items, total
|
||||
|
||||
def get_by_user(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[AuditLog]:
|
||||
"""Get audit logs for a specific user."""
|
||||
return db.query(AuditLog)\
|
||||
.filter(AuditLog.user_id == user_id)\
|
||||
.order_by(desc(AuditLog.created_at))\
|
||||
.offset(skip)\
|
||||
.limit(limit)\
|
||||
.all()
|
||||
|
||||
def get_recent(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
hours: int = 24,
|
||||
limit: int = 100
|
||||
) -> List[AuditLog]:
|
||||
"""Get recent audit logs within specified hours."""
|
||||
since = datetime.utcnow() - timedelta(hours=hours)
|
||||
return db.query(AuditLog)\
|
||||
.filter(AuditLog.created_at >= since)\
|
||||
.order_by(desc(AuditLog.created_at))\
|
||||
.limit(limit)\
|
||||
.all()
|
||||
|
||||
def get_stats(self, db: Session) -> dict[str, Any]:
|
||||
"""Get audit log statistics."""
|
||||
now = datetime.utcnow()
|
||||
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
week_start = today_start - timedelta(days=today_start.weekday())
|
||||
month_start = today_start.replace(day=1)
|
||||
|
||||
# Total entries
|
||||
total = db.query(func.count(AuditLog.id)).scalar()
|
||||
|
||||
# Entries today
|
||||
entries_today = db.query(func.count(AuditLog.id))\
|
||||
.filter(AuditLog.created_at >= today_start)\
|
||||
.scalar()
|
||||
|
||||
# Entries this week
|
||||
entries_week = db.query(func.count(AuditLog.id))\
|
||||
.filter(AuditLog.created_at >= week_start)\
|
||||
.scalar()
|
||||
|
||||
# Entries this month
|
||||
entries_month = db.query(func.count(AuditLog.id))\
|
||||
.filter(AuditLog.created_at >= month_start)\
|
||||
.scalar()
|
||||
|
||||
# Actions breakdown
|
||||
actions_query = db.query(
|
||||
AuditLog.action,
|
||||
func.count(AuditLog.id).label('count')
|
||||
).group_by(AuditLog.action).all()
|
||||
actions_breakdown = {action: count for action, count in actions_query}
|
||||
|
||||
# Top users (by action count)
|
||||
top_users_query = db.query(
|
||||
AuditLog.user_id,
|
||||
AuditLog.username,
|
||||
func.count(AuditLog.id).label('count')
|
||||
).filter(AuditLog.user_id.isnot(None))\
|
||||
.group_by(AuditLog.user_id, AuditLog.username)\
|
||||
.order_by(desc('count'))\
|
||||
.limit(10)\
|
||||
.all()
|
||||
top_users = [
|
||||
{"user_id": uid, "username": uname, "count": count}
|
||||
for uid, uname, count in top_users_query
|
||||
]
|
||||
|
||||
# Recent failures (last 24h)
|
||||
recent_failures = db.query(func.count(AuditLog.id))\
|
||||
.filter(AuditLog.status == "failure")\
|
||||
.filter(AuditLog.created_at >= today_start - timedelta(days=1))\
|
||||
.scalar()
|
||||
|
||||
return {
|
||||
"total_entries": total or 0,
|
||||
"entries_today": entries_today or 0,
|
||||
"entries_this_week": entries_week or 0,
|
||||
"entries_this_month": entries_month or 0,
|
||||
"actions_breakdown": actions_breakdown,
|
||||
"top_users": top_users,
|
||||
"recent_failures": recent_failures or 0
|
||||
}
|
||||
|
||||
def delete_old(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
days: int = 90
|
||||
) -> int:
|
||||
"""Delete audit logs older than specified days."""
|
||||
cutoff = datetime.utcnow() - timedelta(days=days)
|
||||
count = db.query(AuditLog)\
|
||||
.filter(AuditLog.created_at < cutoff)\
|
||||
.delete()
|
||||
db.commit()
|
||||
return count
|
||||
|
||||
def get_distinct_actions(self, db: Session) -> List[str]:
|
||||
"""Get list of distinct action types."""
|
||||
result = db.query(AuditLog.action).distinct().all()
|
||||
return [r[0] for r in result]
|
||||
|
||||
def get_distinct_resource_types(self, db: Session) -> List[str]:
|
||||
"""Get list of distinct resource types."""
|
||||
result = db.query(AuditLog.resource_type)\
|
||||
.filter(AuditLog.resource_type.isnot(None))\
|
||||
.distinct().all()
|
||||
return [r[0] for r in result]
|
||||
|
||||
|
||||
# Create instance
|
||||
audit_log = CRUDAuditLog()
|
||||
264
backend/app/crud/file.py
Normal file
264
backend/app/crud/file.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""CRUD operations and storage service for files."""
|
||||
|
||||
import os
|
||||
import json
|
||||
import hashlib
|
||||
import shutil
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, BinaryIO
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.file import StoredFile
|
||||
from app.schemas.file import FileCreate, FileUpdate, ALLOWED_CONTENT_TYPES, MAX_FILE_SIZE
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class FileStorageService:
|
||||
"""Service for handling file storage operations."""
|
||||
|
||||
def __init__(self, storage_path: str = None):
|
||||
"""Initialize the storage service."""
|
||||
configured_path = storage_path or os.getenv("FILE_STORAGE_PATH")
|
||||
if configured_path:
|
||||
self.storage_path = Path(configured_path)
|
||||
else:
|
||||
# Prefer persistent storage when running in the container (bind-mounted /config).
|
||||
self.storage_path = Path("/config/uploads") if Path("/config").exists() else Path("./uploads")
|
||||
self.storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _get_file_path(self, file_id: str, filename: str) -> Path:
|
||||
"""Generate the storage path for a file."""
|
||||
# Organize files by date and ID for better management
|
||||
date_prefix = datetime.utcnow().strftime("%Y/%m")
|
||||
dir_path = self.storage_path / date_prefix
|
||||
dir_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Use file ID + original extension
|
||||
ext = Path(filename).suffix
|
||||
return dir_path / f"{file_id}{ext}"
|
||||
|
||||
def _calculate_hash(self, file: BinaryIO) -> str:
|
||||
"""Calculate SHA-256 hash of file contents."""
|
||||
sha256 = hashlib.sha256()
|
||||
for chunk in iter(lambda: file.read(8192), b""):
|
||||
sha256.update(chunk)
|
||||
file.seek(0) # Reset file position
|
||||
return sha256.hexdigest()
|
||||
|
||||
def save_file(
|
||||
self,
|
||||
file: BinaryIO,
|
||||
filename: str,
|
||||
file_id: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Save a file to storage.
|
||||
Returns (relative_path, file_hash).
|
||||
"""
|
||||
# Calculate hash
|
||||
file_hash = self._calculate_hash(file)
|
||||
|
||||
# Get storage path
|
||||
file_path = self._get_file_path(file_id, filename)
|
||||
relative_path = str(file_path.relative_to(self.storage_path))
|
||||
|
||||
# Save file
|
||||
with open(file_path, "wb") as f:
|
||||
shutil.copyfileobj(file, f)
|
||||
|
||||
return relative_path, file_hash
|
||||
|
||||
def get_file_path(self, relative_path: str) -> Path:
|
||||
"""Get the full path for a stored file."""
|
||||
return self.storage_path / relative_path
|
||||
|
||||
def delete_file(self, relative_path: str) -> bool:
|
||||
"""Delete a file from storage."""
|
||||
try:
|
||||
file_path = self.storage_path / relative_path
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def file_exists(self, relative_path: str) -> bool:
|
||||
"""Check if a file exists in storage."""
|
||||
return (self.storage_path / relative_path).exists()
|
||||
|
||||
|
||||
class CRUDFile:
|
||||
"""CRUD operations for stored files."""
|
||||
|
||||
def __init__(self):
|
||||
self.storage = FileStorageService()
|
||||
|
||||
def get(self, db: Session, id: str) -> Optional[StoredFile]:
|
||||
"""Get a file by ID."""
|
||||
return db.query(StoredFile).filter(
|
||||
StoredFile.id == id,
|
||||
StoredFile.is_deleted == False
|
||||
).first()
|
||||
|
||||
def get_by_hash(self, db: Session, file_hash: str) -> Optional[StoredFile]:
|
||||
"""Get a file by its hash (for deduplication)."""
|
||||
return db.query(StoredFile).filter(
|
||||
StoredFile.file_hash == file_hash,
|
||||
StoredFile.is_deleted == False
|
||||
).first()
|
||||
|
||||
def get_multi(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
uploaded_by: Optional[str] = None,
|
||||
is_public: Optional[bool] = None,
|
||||
content_type: Optional[str] = None
|
||||
) -> List[StoredFile]:
|
||||
"""Get multiple files with filtering."""
|
||||
query = db.query(StoredFile).filter(StoredFile.is_deleted == False)
|
||||
|
||||
if uploaded_by:
|
||||
query = query.filter(StoredFile.uploaded_by == uploaded_by)
|
||||
if is_public is not None:
|
||||
query = query.filter(StoredFile.is_public == is_public)
|
||||
if content_type:
|
||||
query = query.filter(StoredFile.content_type.like(f"{content_type}%"))
|
||||
|
||||
return query.order_by(StoredFile.created_at.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
def count(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
uploaded_by: Optional[str] = None,
|
||||
is_public: Optional[bool] = None
|
||||
) -> int:
|
||||
"""Count files with optional filtering."""
|
||||
query = db.query(StoredFile).filter(StoredFile.is_deleted == False)
|
||||
|
||||
if uploaded_by:
|
||||
query = query.filter(StoredFile.uploaded_by == uploaded_by)
|
||||
if is_public is not None:
|
||||
query = query.filter(StoredFile.is_public == is_public)
|
||||
|
||||
return query.count()
|
||||
|
||||
def create(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
file: BinaryIO,
|
||||
filename: str,
|
||||
content_type: Optional[str],
|
||||
size_bytes: int,
|
||||
uploaded_by: Optional[str] = None,
|
||||
metadata: Optional[FileCreate] = None
|
||||
) -> StoredFile:
|
||||
"""Create a new file record and save the file."""
|
||||
file_id = str(uuid.uuid4())
|
||||
|
||||
# Save file to storage
|
||||
storage_path, file_hash = self.storage.save_file(file, filename, file_id)
|
||||
|
||||
# Create database record
|
||||
db_obj = StoredFile(
|
||||
id=file_id,
|
||||
original_filename=filename,
|
||||
content_type=content_type,
|
||||
size_bytes=size_bytes,
|
||||
storage_path=storage_path,
|
||||
storage_type="local",
|
||||
file_hash=file_hash,
|
||||
uploaded_by=uploaded_by,
|
||||
description=metadata.description if metadata else None,
|
||||
tags=json.dumps(metadata.tags) if metadata and metadata.tags else None,
|
||||
is_public=metadata.is_public if metadata else False
|
||||
)
|
||||
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def update(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: StoredFile,
|
||||
obj_in: FileUpdate
|
||||
) -> StoredFile:
|
||||
"""Update file metadata."""
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
if "tags" in update_data and update_data["tags"] is not None:
|
||||
update_data["tags"] = json.dumps(update_data["tags"])
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(db_obj, field, value)
|
||||
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def soft_delete(self, db: Session, *, id: str) -> Optional[StoredFile]:
|
||||
"""Soft delete a file (marks as deleted but keeps record)."""
|
||||
obj = db.query(StoredFile).filter(StoredFile.id == id).first()
|
||||
if obj:
|
||||
obj.is_deleted = True
|
||||
obj.deleted_at = datetime.utcnow()
|
||||
db.add(obj)
|
||||
db.commit()
|
||||
db.refresh(obj)
|
||||
return obj
|
||||
|
||||
def hard_delete(self, db: Session, *, id: str) -> bool:
|
||||
"""Permanently delete a file and its record."""
|
||||
obj = db.query(StoredFile).filter(StoredFile.id == id).first()
|
||||
if obj:
|
||||
# Delete physical file
|
||||
self.storage.delete_file(obj.storage_path)
|
||||
# Delete database record
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_file_content(self, db_obj: StoredFile) -> Optional[Path]:
|
||||
"""Get the path to the actual file."""
|
||||
file_path = self.storage.get_file_path(db_obj.storage_path)
|
||||
if file_path.exists():
|
||||
return file_path
|
||||
return None
|
||||
|
||||
def validate_upload(
|
||||
self,
|
||||
content_type: Optional[str],
|
||||
size_bytes: int,
|
||||
allowed_types: List[str] = None,
|
||||
max_size: int = None
|
||||
) -> tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Validate a file upload.
|
||||
Returns (is_valid, error_message).
|
||||
"""
|
||||
allowed = allowed_types or ALLOWED_CONTENT_TYPES
|
||||
max_size = max_size or MAX_FILE_SIZE
|
||||
|
||||
if size_bytes > max_size:
|
||||
return False, f"File size exceeds maximum allowed ({max_size // (1024*1024)} MB)"
|
||||
|
||||
if content_type and content_type not in allowed:
|
||||
return False, f"File type '{content_type}' is not allowed"
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
# Singleton instances
|
||||
file_storage = CRUDFile()
|
||||
233
backend/app/crud/notification.py
Normal file
233
backend/app/crud/notification.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""CRUD operations for Notification model."""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, desc
|
||||
|
||||
from app.models.notification import Notification
|
||||
from app.schemas.notification import NotificationCreate
|
||||
|
||||
|
||||
class CRUDNotification:
|
||||
"""CRUD operations for Notification model."""
|
||||
|
||||
def create(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
obj_in: NotificationCreate
|
||||
) -> Notification:
|
||||
"""Create a new notification."""
|
||||
extra_data_str = json.dumps(obj_in.extra_data) if obj_in.extra_data else None
|
||||
|
||||
db_obj = Notification(
|
||||
user_id=obj_in.user_id,
|
||||
title=obj_in.title,
|
||||
message=obj_in.message,
|
||||
type=obj_in.type,
|
||||
link=obj_in.link,
|
||||
extra_data=extra_data_str,
|
||||
is_read=False
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def create_for_user(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: str,
|
||||
title: str,
|
||||
message: Optional[str] = None,
|
||||
type: str = "info",
|
||||
link: Optional[str] = None,
|
||||
extra_data: Optional[dict] = None
|
||||
) -> Notification:
|
||||
"""Convenience method to create a notification for a user."""
|
||||
obj_in = NotificationCreate(
|
||||
user_id=user_id,
|
||||
title=title,
|
||||
message=message,
|
||||
type=type,
|
||||
link=link,
|
||||
extra_data=extra_data
|
||||
)
|
||||
return self.create(db, obj_in=obj_in)
|
||||
|
||||
def create_for_all_users(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
title: str,
|
||||
message: Optional[str] = None,
|
||||
type: str = "system",
|
||||
link: Optional[str] = None,
|
||||
extra_data: Optional[dict] = None
|
||||
) -> int:
|
||||
"""Create a notification for all users (system notification)."""
|
||||
from app.models.user import User
|
||||
|
||||
users = db.query(User).filter(User.is_active == True).all()
|
||||
count = 0
|
||||
|
||||
for user in users:
|
||||
self.create_for_user(
|
||||
db,
|
||||
user_id=user.id,
|
||||
title=title,
|
||||
message=message,
|
||||
type=type,
|
||||
link=link,
|
||||
extra_data=extra_data
|
||||
)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
def get(self, db: Session, id: str) -> Optional[Notification]:
|
||||
"""Get a notification by ID."""
|
||||
return db.query(Notification).filter(Notification.id == id).first()
|
||||
|
||||
def get_multi_by_user(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
unread_only: bool = False
|
||||
) -> List[Notification]:
|
||||
"""Get notifications for a user."""
|
||||
query = db.query(Notification).filter(Notification.user_id == user_id)
|
||||
|
||||
if unread_only:
|
||||
query = query.filter(Notification.is_read == False)
|
||||
|
||||
return query.order_by(desc(Notification.created_at))\
|
||||
.offset(skip)\
|
||||
.limit(limit)\
|
||||
.all()
|
||||
|
||||
def count_by_user(self, db: Session, user_id: str) -> int:
|
||||
"""Count total notifications for a user."""
|
||||
return db.query(Notification).filter(Notification.user_id == user_id).count()
|
||||
|
||||
def count_unread_by_user(self, db: Session, user_id: str) -> int:
|
||||
"""Count unread notifications for a user."""
|
||||
return db.query(Notification)\
|
||||
.filter(Notification.user_id == user_id)\
|
||||
.filter(Notification.is_read == False)\
|
||||
.count()
|
||||
|
||||
def mark_as_read(self, db: Session, *, id: str, user_id: str) -> Optional[Notification]:
|
||||
"""Mark a notification as read."""
|
||||
db_obj = db.query(Notification)\
|
||||
.filter(Notification.id == id)\
|
||||
.filter(Notification.user_id == user_id)\
|
||||
.first()
|
||||
|
||||
if db_obj and not db_obj.is_read:
|
||||
db_obj.is_read = True
|
||||
db_obj.read_at = datetime.utcnow()
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
|
||||
return db_obj
|
||||
|
||||
def mark_all_as_read(self, db: Session, *, user_id: str) -> int:
|
||||
"""Mark all notifications as read for a user."""
|
||||
count = db.query(Notification)\
|
||||
.filter(Notification.user_id == user_id)\
|
||||
.filter(Notification.is_read == False)\
|
||||
.update({
|
||||
"is_read": True,
|
||||
"read_at": datetime.utcnow()
|
||||
})
|
||||
db.commit()
|
||||
return count
|
||||
|
||||
def mark_multiple_as_read(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: str,
|
||||
notification_ids: List[str]
|
||||
) -> int:
|
||||
"""Mark multiple notifications as read."""
|
||||
count = db.query(Notification)\
|
||||
.filter(Notification.id.in_(notification_ids))\
|
||||
.filter(Notification.user_id == user_id)\
|
||||
.filter(Notification.is_read == False)\
|
||||
.update({
|
||||
"is_read": True,
|
||||
"read_at": datetime.utcnow()
|
||||
}, synchronize_session=False)
|
||||
db.commit()
|
||||
return count
|
||||
|
||||
def delete(self, db: Session, *, id: str, user_id: str) -> bool:
|
||||
"""Delete a notification."""
|
||||
obj = db.query(Notification)\
|
||||
.filter(Notification.id == id)\
|
||||
.filter(Notification.user_id == user_id)\
|
||||
.first()
|
||||
|
||||
if obj:
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete_all_read(self, db: Session, *, user_id: str) -> int:
|
||||
"""Delete all read notifications for a user."""
|
||||
count = db.query(Notification)\
|
||||
.filter(Notification.user_id == user_id)\
|
||||
.filter(Notification.is_read == True)\
|
||||
.delete()
|
||||
db.commit()
|
||||
return count
|
||||
|
||||
def delete_multiple(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: str,
|
||||
notification_ids: List[str]
|
||||
) -> int:
|
||||
"""Delete multiple notifications."""
|
||||
count = db.query(Notification)\
|
||||
.filter(Notification.id.in_(notification_ids))\
|
||||
.filter(Notification.user_id == user_id)\
|
||||
.delete(synchronize_session=False)
|
||||
db.commit()
|
||||
return count
|
||||
|
||||
def get_stats_by_user(self, db: Session, user_id: str) -> dict:
|
||||
"""Get notification statistics for a user."""
|
||||
total = self.count_by_user(db, user_id)
|
||||
unread = self.count_unread_by_user(db, user_id)
|
||||
|
||||
# Count by type
|
||||
type_counts = db.query(
|
||||
Notification.type,
|
||||
func.count(Notification.id).label('count')
|
||||
).filter(Notification.user_id == user_id)\
|
||||
.group_by(Notification.type)\
|
||||
.all()
|
||||
|
||||
by_type = {t: c for t, c in type_counts}
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"unread": unread,
|
||||
"by_type": by_type
|
||||
}
|
||||
|
||||
|
||||
# Create instance
|
||||
notification = CRUDNotification()
|
||||
274
backend/app/crud/session.py
Normal file
274
backend/app/crud/session.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""CRUD operations for User Session model."""
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
|
||||
from app.models.session import UserSession
|
||||
from app.schemas.session import SessionCreate
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""Hash a token for secure storage."""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
def parse_user_agent(user_agent: str) -> dict:
|
||||
"""Parse user agent string to extract device info."""
|
||||
result = {
|
||||
"device_type": "desktop",
|
||||
"browser": "Unknown",
|
||||
"os": "Unknown"
|
||||
}
|
||||
|
||||
if not user_agent:
|
||||
return result
|
||||
|
||||
ua_lower = user_agent.lower()
|
||||
|
||||
# Detect device type
|
||||
if "mobile" in ua_lower or "android" in ua_lower and "mobile" in ua_lower:
|
||||
result["device_type"] = "mobile"
|
||||
elif "tablet" in ua_lower or "ipad" in ua_lower:
|
||||
result["device_type"] = "tablet"
|
||||
|
||||
# Detect OS
|
||||
if "windows" in ua_lower:
|
||||
result["os"] = "Windows"
|
||||
elif "mac os" in ua_lower or "macintosh" in ua_lower:
|
||||
result["os"] = "macOS"
|
||||
elif "linux" in ua_lower:
|
||||
result["os"] = "Linux"
|
||||
elif "android" in ua_lower:
|
||||
result["os"] = "Android"
|
||||
elif "iphone" in ua_lower or "ipad" in ua_lower:
|
||||
result["os"] = "iOS"
|
||||
|
||||
# Detect browser
|
||||
if "firefox" in ua_lower:
|
||||
result["browser"] = "Firefox"
|
||||
elif "edg" in ua_lower:
|
||||
result["browser"] = "Edge"
|
||||
elif "chrome" in ua_lower:
|
||||
result["browser"] = "Chrome"
|
||||
elif "safari" in ua_lower:
|
||||
result["browser"] = "Safari"
|
||||
elif "opera" in ua_lower:
|
||||
result["browser"] = "Opera"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class CRUDSession:
|
||||
"""CRUD operations for User Session model."""
|
||||
|
||||
def create(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: str,
|
||||
token: str,
|
||||
user_agent: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
expires_at: Optional[datetime] = None
|
||||
) -> UserSession:
|
||||
"""Create a new session."""
|
||||
token_hash = hash_token(token)
|
||||
parsed_ua = parse_user_agent(user_agent or "")
|
||||
|
||||
# Generate device name
|
||||
device_name = f"{parsed_ua['browser']} on {parsed_ua['os']}"
|
||||
|
||||
db_obj = UserSession(
|
||||
user_id=user_id,
|
||||
token_hash=token_hash,
|
||||
device_name=device_name,
|
||||
device_type=parsed_ua["device_type"],
|
||||
browser=parsed_ua["browser"],
|
||||
os=parsed_ua["os"],
|
||||
user_agent=user_agent[:500] if user_agent else None,
|
||||
ip_address=ip_address,
|
||||
expires_at=expires_at,
|
||||
is_active=True
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def get(self, db: Session, id: str) -> Optional[UserSession]:
|
||||
"""Get a session by ID."""
|
||||
return db.query(UserSession).filter(UserSession.id == id).first()
|
||||
|
||||
def get_by_token(self, db: Session, token: str) -> Optional[UserSession]:
|
||||
"""Get a session by token."""
|
||||
token_hash = hash_token(token)
|
||||
return db.query(UserSession).filter(UserSession.token_hash == token_hash).first()
|
||||
|
||||
def get_multi_by_user(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: str,
|
||||
active_only: bool = True
|
||||
) -> List[UserSession]:
|
||||
"""Get all sessions for a user."""
|
||||
query = db.query(UserSession).filter(UserSession.user_id == user_id)
|
||||
|
||||
if active_only:
|
||||
query = query.filter(UserSession.is_active == True)
|
||||
|
||||
return query.order_by(desc(UserSession.last_active_at)).all()
|
||||
|
||||
def count_by_user(self, db: Session, user_id: str, active_only: bool = True) -> int:
|
||||
"""Count sessions for a user."""
|
||||
query = db.query(UserSession).filter(UserSession.user_id == user_id)
|
||||
if active_only:
|
||||
query = query.filter(UserSession.is_active == True)
|
||||
return query.count()
|
||||
|
||||
def count_active_by_user(self, db: Session, user_id: str) -> int:
|
||||
"""Count active sessions for a user."""
|
||||
return self.count_by_user(db, user_id, active_only=True)
|
||||
|
||||
def update_activity(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
token: str,
|
||||
ip_address: Optional[str] = None
|
||||
) -> Optional[UserSession]:
|
||||
"""Update session last activity."""
|
||||
token_hash = hash_token(token)
|
||||
db_obj = db.query(UserSession).filter(UserSession.token_hash == token_hash).first()
|
||||
|
||||
if db_obj and db_obj.is_active:
|
||||
db_obj.last_active_at = datetime.utcnow()
|
||||
if ip_address:
|
||||
db_obj.ip_address = ip_address
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
|
||||
return db_obj
|
||||
|
||||
def mark_as_current(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
session_id: str,
|
||||
user_id: str
|
||||
) -> None:
|
||||
"""Mark a session as current and unmark others."""
|
||||
# Unmark all sessions for user
|
||||
db.query(UserSession)\
|
||||
.filter(UserSession.user_id == user_id)\
|
||||
.update({"is_current": False})
|
||||
|
||||
# Mark specific session as current
|
||||
db.query(UserSession)\
|
||||
.filter(UserSession.id == session_id)\
|
||||
.update({"is_current": True})
|
||||
|
||||
db.commit()
|
||||
|
||||
def revoke(self, db: Session, *, id: str, user_id: str) -> Optional[UserSession]:
|
||||
"""Revoke a specific session."""
|
||||
db_obj = db.query(UserSession)\
|
||||
.filter(UserSession.id == id)\
|
||||
.filter(UserSession.user_id == user_id)\
|
||||
.first()
|
||||
|
||||
if db_obj:
|
||||
db_obj.is_active = False
|
||||
db_obj.revoked_at = datetime.utcnow()
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
|
||||
return db_obj
|
||||
|
||||
def revoke_all_except(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: str,
|
||||
except_session_id: Optional[str] = None
|
||||
) -> int:
|
||||
"""Revoke all sessions for a user except the specified one."""
|
||||
query = db.query(UserSession)\
|
||||
.filter(UserSession.user_id == user_id)\
|
||||
.filter(UserSession.is_active == True)
|
||||
|
||||
if except_session_id:
|
||||
query = query.filter(UserSession.id != except_session_id)
|
||||
|
||||
count = query.update({
|
||||
"is_active": False,
|
||||
"revoked_at": datetime.utcnow()
|
||||
})
|
||||
db.commit()
|
||||
return count
|
||||
|
||||
def revoke_multiple(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: str,
|
||||
session_ids: List[str]
|
||||
) -> int:
|
||||
"""Revoke multiple sessions."""
|
||||
count = db.query(UserSession)\
|
||||
.filter(UserSession.id.in_(session_ids))\
|
||||
.filter(UserSession.user_id == user_id)\
|
||||
.filter(UserSession.is_active == True)\
|
||||
.update({
|
||||
"is_active": False,
|
||||
"revoked_at": datetime.utcnow()
|
||||
}, synchronize_session=False)
|
||||
db.commit()
|
||||
return count
|
||||
|
||||
def cleanup_expired(self, db: Session) -> int:
|
||||
"""Clean up expired sessions."""
|
||||
now = datetime.utcnow()
|
||||
count = db.query(UserSession)\
|
||||
.filter(UserSession.expires_at < now)\
|
||||
.filter(UserSession.is_active == True)\
|
||||
.update({
|
||||
"is_active": False,
|
||||
"revoked_at": now
|
||||
})
|
||||
db.commit()
|
||||
return count
|
||||
|
||||
def is_valid(self, db: Session, token: str) -> bool:
|
||||
"""Check if a session token is valid."""
|
||||
session = self.get_by_token(db, token)
|
||||
if not session:
|
||||
return False
|
||||
|
||||
if not session.is_active:
|
||||
return False
|
||||
|
||||
if session.expires_at and session.expires_at < datetime.utcnow():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def delete_old_inactive(self, db: Session, days: int = 30) -> int:
|
||||
"""Delete old inactive sessions."""
|
||||
cutoff = datetime.utcnow() - timedelta(days=days)
|
||||
count = db.query(UserSession)\
|
||||
.filter(UserSession.is_active == False)\
|
||||
.filter(UserSession.revoked_at < cutoff)\
|
||||
.delete()
|
||||
db.commit()
|
||||
return count
|
||||
|
||||
|
||||
# Create instance
|
||||
session = CRUDSession()
|
||||
345
backend/app/crud/webhook.py
Normal file
345
backend/app/crud/webhook.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""CRUD operations for webhooks."""
|
||||
|
||||
import json
|
||||
import secrets
|
||||
import hashlib
|
||||
import hmac
|
||||
import httpx
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Any
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.webhook import Webhook, WebhookDelivery
|
||||
from app.schemas.webhook import WebhookCreate, WebhookUpdate
|
||||
|
||||
|
||||
class CRUDWebhook:
|
||||
"""CRUD operations for webhooks."""
|
||||
|
||||
def get(self, db: Session, id: str) -> Optional[Webhook]:
|
||||
"""Get a webhook by ID."""
|
||||
return db.query(Webhook).filter(Webhook.id == id).first()
|
||||
|
||||
def get_multi(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None
|
||||
) -> List[Webhook]:
|
||||
"""Get multiple webhooks."""
|
||||
query = db.query(Webhook)
|
||||
if is_active is not None:
|
||||
query = query.filter(Webhook.is_active == is_active)
|
||||
return query.order_by(Webhook.created_at.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
def get_by_event(self, db: Session, event_type: str) -> List[Webhook]:
|
||||
"""Get all active webhooks that subscribe to an event type."""
|
||||
webhooks = db.query(Webhook).filter(Webhook.is_active == True).all()
|
||||
matching = []
|
||||
for webhook in webhooks:
|
||||
events = json.loads(webhook.events) if webhook.events else []
|
||||
if "*" in events or event_type in events:
|
||||
matching.append(webhook)
|
||||
return matching
|
||||
|
||||
def create(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
obj_in: WebhookCreate,
|
||||
created_by: Optional[str] = None
|
||||
) -> Webhook:
|
||||
"""Create a new webhook with a generated secret."""
|
||||
# Generate a secret for signature verification
|
||||
secret = secrets.token_hex(32)
|
||||
|
||||
db_obj = Webhook(
|
||||
name=obj_in.name,
|
||||
url=obj_in.url,
|
||||
secret=secret,
|
||||
events=json.dumps(obj_in.events),
|
||||
is_active=obj_in.is_active,
|
||||
retry_count=obj_in.retry_count,
|
||||
timeout_seconds=obj_in.timeout_seconds,
|
||||
created_by=created_by
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def update(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: Webhook,
|
||||
obj_in: WebhookUpdate
|
||||
) -> Webhook:
|
||||
"""Update a webhook."""
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
if "events" in update_data:
|
||||
update_data["events"] = json.dumps(update_data["events"])
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(db_obj, field, value)
|
||||
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def delete(self, db: Session, *, id: str) -> Optional[Webhook]:
|
||||
"""Delete a webhook."""
|
||||
obj = db.query(Webhook).filter(Webhook.id == id).first()
|
||||
if obj:
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
return obj
|
||||
|
||||
def regenerate_secret(self, db: Session, *, db_obj: Webhook) -> Webhook:
|
||||
"""Regenerate the webhook secret."""
|
||||
db_obj.secret = secrets.token_hex(32)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def count(self, db: Session) -> int:
|
||||
"""Count total webhooks."""
|
||||
return db.query(Webhook).count()
|
||||
|
||||
|
||||
class CRUDWebhookDelivery:
|
||||
"""CRUD operations for webhook deliveries."""
|
||||
|
||||
def get(self, db: Session, id: str) -> Optional[WebhookDelivery]:
|
||||
"""Get a delivery by ID."""
|
||||
return db.query(WebhookDelivery).filter(WebhookDelivery.id == id).first()
|
||||
|
||||
def get_by_webhook(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
webhook_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 50
|
||||
) -> List[WebhookDelivery]:
|
||||
"""Get deliveries for a specific webhook."""
|
||||
return (
|
||||
db.query(WebhookDelivery)
|
||||
.filter(WebhookDelivery.webhook_id == webhook_id)
|
||||
.order_by(WebhookDelivery.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_pending_retries(self, db: Session) -> List[WebhookDelivery]:
|
||||
"""Get deliveries that need to be retried."""
|
||||
now = datetime.utcnow()
|
||||
return (
|
||||
db.query(WebhookDelivery)
|
||||
.filter(
|
||||
WebhookDelivery.status == "failed",
|
||||
WebhookDelivery.next_retry_at <= now
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
def create(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
webhook_id: str,
|
||||
event_type: str,
|
||||
payload: dict
|
||||
) -> WebhookDelivery:
|
||||
"""Create a new webhook delivery record."""
|
||||
db_obj = WebhookDelivery(
|
||||
webhook_id=webhook_id,
|
||||
event_type=event_type,
|
||||
payload=json.dumps(payload),
|
||||
status="pending"
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: WebhookDelivery,
|
||||
status: str,
|
||||
status_code: Optional[int] = None,
|
||||
response_body: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
schedule_retry: bool = False,
|
||||
max_retries: int = 3
|
||||
) -> WebhookDelivery:
|
||||
"""Update delivery status."""
|
||||
db_obj.status = status
|
||||
db_obj.status_code = status_code
|
||||
db_obj.response_body = response_body[:1000] if response_body else None
|
||||
db_obj.error_message = error_message
|
||||
db_obj.attempt_count += 1
|
||||
|
||||
if status == "success":
|
||||
db_obj.delivered_at = datetime.utcnow()
|
||||
db_obj.next_retry_at = None
|
||||
elif status == "failed" and schedule_retry and db_obj.attempt_count < max_retries:
|
||||
# Exponential backoff: 1min, 5min, 30min
|
||||
delays = [60, 300, 1800]
|
||||
delay = delays[min(db_obj.attempt_count - 1, len(delays) - 1)]
|
||||
db_obj.next_retry_at = datetime.utcnow() + timedelta(seconds=delay)
|
||||
else:
|
||||
db_obj.next_retry_at = None
|
||||
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
|
||||
class WebhookService:
|
||||
"""Service for triggering and delivering webhooks."""
|
||||
|
||||
def __init__(self):
|
||||
self.webhook_crud = CRUDWebhook()
|
||||
self.delivery_crud = CRUDWebhookDelivery()
|
||||
|
||||
def generate_signature(self, payload: str, secret: str) -> str:
|
||||
"""Generate HMAC-SHA256 signature for payload."""
|
||||
return hmac.new(
|
||||
secret.encode(),
|
||||
payload.encode(),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
async def trigger_event(
|
||||
self,
|
||||
db: Session,
|
||||
event_type: str,
|
||||
payload: dict
|
||||
) -> List[WebhookDelivery]:
|
||||
"""Trigger webhooks for an event."""
|
||||
webhooks = self.webhook_crud.get_by_event(db, event_type)
|
||||
deliveries = []
|
||||
|
||||
for webhook in webhooks:
|
||||
delivery = self.delivery_crud.create(
|
||||
db,
|
||||
webhook_id=webhook.id,
|
||||
event_type=event_type,
|
||||
payload=payload
|
||||
)
|
||||
deliveries.append(delivery)
|
||||
|
||||
# Attempt delivery
|
||||
await self.deliver(db, webhook, delivery)
|
||||
|
||||
return deliveries
|
||||
|
||||
async def deliver(
|
||||
self,
|
||||
db: Session,
|
||||
webhook: Webhook,
|
||||
delivery: WebhookDelivery
|
||||
) -> bool:
|
||||
"""Deliver a webhook."""
|
||||
payload_str = delivery.payload
|
||||
signature = self.generate_signature(payload_str, webhook.secret)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-Webhook-Signature": signature,
|
||||
"X-Webhook-Event": delivery.event_type,
|
||||
"X-Webhook-Delivery-Id": delivery.id
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=webhook.timeout_seconds) as client:
|
||||
response = await client.post(
|
||||
webhook.url,
|
||||
content=payload_str,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
if response.status_code >= 200 and response.status_code < 300:
|
||||
self.delivery_crud.update_status(
|
||||
db,
|
||||
db_obj=delivery,
|
||||
status="success",
|
||||
status_code=response.status_code,
|
||||
response_body=response.text
|
||||
)
|
||||
webhook.success_count += 1
|
||||
webhook.last_triggered_at = datetime.utcnow()
|
||||
db.add(webhook)
|
||||
db.commit()
|
||||
return True
|
||||
else:
|
||||
self.delivery_crud.update_status(
|
||||
db,
|
||||
db_obj=delivery,
|
||||
status="failed",
|
||||
status_code=response.status_code,
|
||||
response_body=response.text,
|
||||
error_message=f"HTTP {response.status_code}",
|
||||
schedule_retry=True,
|
||||
max_retries=webhook.retry_count
|
||||
)
|
||||
webhook.failure_count += 1
|
||||
db.add(webhook)
|
||||
db.commit()
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.delivery_crud.update_status(
|
||||
db,
|
||||
db_obj=delivery,
|
||||
status="failed",
|
||||
error_message=str(e),
|
||||
schedule_retry=True,
|
||||
max_retries=webhook.retry_count
|
||||
)
|
||||
webhook.failure_count += 1
|
||||
db.add(webhook)
|
||||
db.commit()
|
||||
return False
|
||||
|
||||
async def test_webhook(
|
||||
self,
|
||||
db: Session,
|
||||
webhook: Webhook,
|
||||
event_type: str = "test.ping",
|
||||
payload: Optional[dict] = None
|
||||
) -> WebhookDelivery:
|
||||
"""Send a test delivery to a webhook."""
|
||||
if payload is None:
|
||||
payload = {
|
||||
"event": event_type,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"test": True,
|
||||
"message": "This is a test webhook delivery"
|
||||
}
|
||||
|
||||
delivery = self.delivery_crud.create(
|
||||
db,
|
||||
webhook_id=webhook.id,
|
||||
event_type=event_type,
|
||||
payload=payload
|
||||
)
|
||||
|
||||
await self.deliver(db, webhook, delivery)
|
||||
return delivery
|
||||
|
||||
|
||||
# Singleton instances
|
||||
webhook = CRUDWebhook()
|
||||
webhook_delivery = CRUDWebhookDelivery()
|
||||
webhook_service = WebhookService()
|
||||
Reference in New Issue
Block a user