"""CRUD operations for API Key model."""
import json
import hashlib
from datetime import datetime, timezone
from typing import Optional, List, Tuple

from sqlalchemy.orm import Session

from app.models.api_key import APIKey, generate_api_key, generate_key_prefix
from app.schemas.api_key import APIKeyCreate, APIKeyUpdate


def hash_api_key(key: str) -> str:
    """Hash an API key for secure storage.

    Returns the hex-encoded SHA-256 digest of the UTF-8 encoded key. Only
    this digest is persisted (``APIKey.key_hash``); lookups in
    ``CRUDAPIKey.get_by_key`` hash the presented key and compare digests.
    """
    return hashlib.sha256(key.encode()).hexdigest()


class CRUDAPIKey:
    """CRUD operations for API Key model.

    Methods that modify rows commit the session they are given before
    returning.
    """

    def create(
        self, db: Session, *, obj_in: APIKeyCreate, user_id: str
    ) -> Tuple[APIKey, str]:
        """
        Create a new API key.

        Returns both the database object and the plain key (shown only once).
        Only the key's SHA-256 hash and the prefix produced by
        ``generate_key_prefix`` are stored; the plain key cannot be recovered
        from the database afterwards.
        """
        # Generate the actual key
        plain_key = generate_api_key()
        key_hash = hash_api_key(plain_key)
        key_prefix = generate_key_prefix(plain_key)

        # Scopes are persisted as JSON text; an empty or missing list becomes NULL.
        scopes_json = json.dumps(obj_in.scopes) if obj_in.scopes else None

        db_obj = APIKey(
            user_id=user_id,
            name=obj_in.name,
            key_hash=key_hash,
            key_prefix=key_prefix,
            scopes=scopes_json,
            expires_at=obj_in.expires_at,
            is_active=True,
            # usage_count is kept as a string (record_usage parses and re-stringifies it).
            usage_count="0",
        )
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj, plain_key

    def get(self, db: Session, id: str) -> Optional[APIKey]:
        """Get an API key by ID, or None if no row matches."""
        return db.query(APIKey).filter(APIKey.id == id).first()

    def get_by_key(self, db: Session, plain_key: str) -> Optional[APIKey]:
        """Get an API key by the plain key (for authentication).

        The presented key is hashed with ``hash_api_key`` and matched against
        the stored hash. An empty or missing key returns None without
        touching the database.
        """
        if not plain_key:
            return None
        key_hash = hash_api_key(plain_key)
        return db.query(APIKey).filter(APIKey.key_hash == key_hash).first()

    def get_multi_by_user(
        self, db: Session, *, user_id: str, skip: int = 0, limit: int = 100
    ) -> List[APIKey]:
        """Get a page of a user's API keys, newest first.

        Args:
            user_id: Owner whose keys are listed.
            skip: Number of rows to skip (offset).
            limit: Maximum number of rows to return.
        """
        return (
            db.query(APIKey)
            .filter(APIKey.user_id == user_id)
            .order_by(APIKey.created_at.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )

    def count_by_user(self, db: Session, user_id: str) -> int:
        """Count API keys for a user."""
        return db.query(APIKey).filter(APIKey.user_id == user_id).count()

    def update(
        self, db: Session, *, db_obj: APIKey, obj_in: APIKeyUpdate
    ) -> APIKey:
        """Update an API key with the fields explicitly set on ``obj_in``.

        Fields that were not set on the update schema are left untouched
        (``exclude_unset=True``).
        """
        update_data = obj_in.model_dump(exclude_unset=True)

        # Scopes are stored as JSON text, mirroring create(): a falsy value becomes NULL.
        if "scopes" in update_data:
            scopes = update_data["scopes"]
            update_data["scopes"] = json.dumps(scopes) if scopes else None

        for field, value in update_data.items():
            setattr(db_obj, field, value)

        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def delete(self, db: Session, *, id: str) -> bool:
        """Delete an API key.

        Returns True if a row was found and deleted, False otherwise.
        """
        obj = self.get(db, id)
        if obj is None:
            return False
        db.delete(obj)
        db.commit()
        return True

    def delete_by_user(self, db: Session, *, user_id: str) -> int:
        """Delete all API keys for a user and return the number of rows deleted."""
        count = db.query(APIKey).filter(APIKey.user_id == user_id).delete()
        db.commit()
        return count

    def revoke(self, db: Session, *, id: str) -> Optional[APIKey]:
        """Revoke (deactivate) an API key without deleting it.

        Returns the updated key, or None if no key has this ID.
        """
        obj = self.get(db, id)
        if obj is not None:
            obj.is_active = False
            db.add(obj)
            db.commit()
            db.refresh(obj)
        return obj

    def record_usage(
        self, db: Session, *, db_obj: APIKey, ip_address: Optional[str] = None
    ) -> APIKey:
        """Record API key usage: last-used time/IP and an incremented counter."""
        # Naive UTC timestamp, identical in value to the deprecated datetime.utcnow().
        db_obj.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
        db_obj.last_used_ip = ip_address
        # NOTE(review): this is a read-modify-write in Python, so concurrent
        # requests using the same key may lose increments.
        db_obj.usage_count = str(int(db_obj.usage_count or "0") + 1)
        db.add(db_obj)
        db.commit()
        db.refresh(db_obj)
        return db_obj

    def authenticate(
        self, db: Session, *, plain_key: str, ip_address: Optional[str] = None
    ) -> Optional[APIKey]:
        """
        Authenticate with an API key.

        Returns the key if valid, None otherwise.
        Also records usage on successful auth.
        """
        key_obj = self.get_by_key(db, plain_key)
        # `is_valid` is defined on the APIKey model (presumably active and
        # not expired; verify against the model).
        if key_obj is None or not key_obj.is_valid:
            return None

        self.record_usage(db, db_obj=key_obj, ip_address=ip_address)
        return key_obj

    def get_scopes(self, api_key: APIKey) -> List[str]:
        """Get scopes for an API key.

        Decodes the JSON-encoded ``scopes`` column. Returns an empty list
        when the column is empty, not valid JSON, or does not decode to a
        list; non-string entries are dropped.
        """
        if not api_key.scopes:
            return []
        try:
            decoded = json.loads(api_key.scopes)
        except (json.JSONDecodeError, TypeError):
            return []
        # Enforce the List[str] contract against stored JSON of another shape.
        if not isinstance(decoded, list):
            return []
        return [scope for scope in decoded if isinstance(scope, str)]


# Create instance
api_key = CRUDAPIKey()