"""CSRF 防护 — Token 校验中间件""" import secrets import time import hashlib from typing import Optional from oss.config import get_config from oss.logger.logger import Log class CSRFProtection: """CSRF Token 生成与验证""" def __init__(self, secret: str = None): config = get_config() self._secret = secret or config.get("CSRF_SECRET", "") if not self._secret: self._secret = hashlib.sha256(config.get("API_KEY", "nebula-csrf-default").encode()).hexdigest() self._token_ttl = config.get("CSRF_TOKEN_TTL", 3600) # 默认1小时 def generate_token(self, session_id: str) -> str: """生成 CSRF Token(绑定 session)""" salt = secrets.token_hex(16) timestamp = int(time.time()) raw = f"{session_id}:{salt}:{timestamp}:{self._secret}" token = hashlib.sha256(raw.encode()).hexdigest() return f"{timestamp}:{salt}:{token}" def verify_token(self, session_id: str, token: str) -> bool: """验证 CSRF Token""" try: parts = token.split(":") if len(parts) != 3: return False timestamp, salt, hash_val = parts # 检查过期 if int(time.time()) - int(timestamp) > self._token_ttl: return False expected = hashlib.sha256(f"{session_id}:{salt}:{timestamp}:{self._secret}".encode()).hexdigest() return hash_val == expected except (ValueError, IndexError): return False SAFE_METHODS = {"GET", "HEAD", "OPTIONS"} @staticmethod def is_safe_method(method: str) -> bool: return method.upper() in CSRFProtection.SAFE_METHODS