"""Middleware chain - CORS / auth / logging / rate limiting / CSRF / input validation."""
import hmac
import json
import time
import threading
from collections import deque
from typing import Callable, Optional, Any

from oss.config import get_config
from oss.logger.logger import Log
from .server import Request, Response
from .rate_limiter import RateLimitMiddleware


def _json_error(status: int, error: str, message: str) -> Response:
    """Build a JSON error response with body ``{"error": error, "message": message}``."""
    return Response(
        status=status,
        body=json.dumps({"error": error, "message": message}),
        headers={"Content-Type": "application/json"},
    )


class Middleware:
    """Middleware base class.

    Subclasses override :meth:`process`. Returning ``next_fn()`` hands control
    to the next middleware in the chain; returning a ``Response`` directly
    short-circuits the remaining middlewares.
    """

    def process(self, ctx: dict[str, Any], next_fn: Callable) -> Optional[Response]:
        return next_fn()


class CorsMiddleware(Middleware):
    """CORS middleware.

    When the request's ``Origin`` header is permitted by the
    ``CORS_ALLOWED_ORIGINS`` config value, the CORS response headers are stored
    in ``ctx["response_headers"]``. This middleware never short-circuits; it
    always delegates to ``next_fn``.
    """

    @staticmethod
    def _allowed_origins():
        """Return the configured allowed origins as an exact-match collection."""
        origins = get_config().get(
            "CORS_ALLOWED_ORIGINS", ["http://localhost:3000", "http://127.0.0.1:3000"]
        )
        if isinstance(origins, str):
            # `origin in "<str>"` would be a substring test (e.g. "http://a.co"
            # is "in" "http://a.com"); treat a string as a comma-separated list.
            origins = [item.strip() for item in origins.split(",") if item.strip()]
        return origins

    def process(self, ctx: dict, next_fn: Callable) -> Optional[Response]:
        allowed_origins = self._allowed_origins()
        req = ctx.get("request")
        origin = req.headers.get("Origin", "") if req else ""
        if not allowed_origins or not origin:
            return next_fn()

        explicitly_allowed = origin in allowed_origins
        if explicitly_allowed or "*" in allowed_origins:
            headers = {
                # The request origin is echoed back, so caches must key on it
                # (see the "Vary" entry below).
                "Access-Control-Allow-Origin": origin,
                "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
                # X-CSRF-Token / X-Session-Id are required by CSRFMiddleware on
                # unsafe methods; without them the browser preflight fails.
                "Access-Control-Allow-Headers": (
                    "Content-Type, Authorization, X-CSRF-Token, X-Session-Id"
                ),
                "Vary": "Origin",
            }
            if explicitly_allowed:
                # Credentials are granted only to explicitly listed origins:
                # echoing an arbitrary origin *with* credentials under "*" would
                # let any site issue credentialed requests.
                headers["Access-Control-Allow-Credentials"] = "true"
            ctx["response_headers"] = headers
        return next_fn()


class AuthMiddleware(Middleware):
    """Auth middleware - JWT + API_KEY dual-mode authentication.

    * Whitelisted public paths and ``OPTIONS`` requests always pass.
    * If ``API_KEY`` is configured, the ``Authorization: Bearer <token>`` value
      must equal it, otherwise a 401 JSON response is returned.
    * If ``API_KEY`` is empty, a bearer token (when present) is verified as a
      JWT; a valid payload is stored in ``ctx["user"]``, an invalid one yields 401.
    """

    @staticmethod
    def _get_public_paths() -> set:
        """Return the public-path whitelist, preferring the ``PUBLIC_PATHS`` config list."""
        configured = get_config().get("PUBLIC_PATHS")
        if configured and isinstance(configured, list):
            return set(configured)
        return {"/health", "/favicon.ico", "/api/status", "/api/health", "/api/login", "/metrics"}

    @staticmethod
    def _bearer_token(req) -> str:
        """Extract the bearer token from the ``Authorization`` header ("" if absent)."""
        auth_header = req.headers.get("Authorization", "") if req else ""
        return auth_header.removeprefix("Bearer ").strip()

    def process(self, ctx: dict, next_fn: Callable) -> Optional[Response]:
        api_key = get_config().get("API_KEY", "")
        req = ctx.get("request")

        if req and (req.path in self._get_public_paths() or req.method == "OPTIONS"):
            return next_fn()

        token = self._bearer_token(req)

        if not api_key:
            # JWT mode (no API_KEY configured).
            if not token:
                # NOTE(review): requests without any token are let through in
                # this mode, so auth is effectively optional when API_KEY is
                # unset — confirm this is intended.
                return next_fn()
            from oss.core.security.jwt_auth import verify_token

            payload = verify_token(token)
            if payload:
                ctx["user"] = payload
                return next_fn()
            return _json_error(401, "Unauthorized", "Token 无效或已过期")

        # API_KEY mode: constant-time comparison to avoid leaking the key via
        # timing; compare bytes so non-ASCII keys don't raise TypeError.
        if token and hmac.compare_digest(
            token.encode("utf-8"), str(api_key).encode("utf-8")
        ):
            return next_fn()

        Log.warn("Core", f"鉴权失败: {req.method} {req.path}" if req else "鉴权失败")
        return _json_error(401, "Unauthorized", "需要有效的认证凭据")


class LoggerMiddleware(Middleware):
    """Logging middleware: logs ``METHOD path`` for every non-silent request."""

    # High-frequency / noise paths that are not logged.
    _silent_paths = {"/api/dashboard/stats", "/favicon.ico", "/health"}

    def process(self, ctx: dict, next_fn: Callable) -> Optional[Response]:
        req = ctx.get("request")
        if req and req.path not in self._silent_paths:
            Log.info("Core", f"{req.method} {req.path}")
        return next_fn()


class CSRFMiddleware(Middleware):
    """CSRF protection middleware.

    Skipped when ``CSRF_ENABLED`` is false or for methods that
    ``CSRFProtection.is_safe_method`` accepts. Otherwise the ``X-CSRF-Token``
    and ``X-Session-Id`` headers must both be present and verify, else a 403
    JSON response is returned.
    """

    def process(self, ctx: dict, next_fn: Callable) -> Optional[Response]:
        if not get_config().get("CSRF_ENABLED", True):
            return next_fn()

        req = ctx.get("request")
        if not req:
            return next_fn()

        # The function-local import must precede the first use of the name:
        # importing it later would make `CSRFProtection` local to the whole
        # function and raise UnboundLocalError on the safe-method check.
        from oss.core.security.csrf import CSRFProtection

        if CSRFProtection.is_safe_method(req.method):
            return next_fn()

        # The CSRF token and session id are read from request headers only.
        token = req.headers.get("X-CSRF-Token", "")
        session_id = req.headers.get("X-Session-Id", "")
        if not token or not session_id:
            return _json_error(403, "Forbidden", "缺少 CSRF Token")

        # NOTE(review): a fresh CSRFProtection is created per request; this
        # assumes verification does not depend on per-instance state — confirm.
        csrf = CSRFProtection()
        if not csrf.verify_token(session_id, token):
            return _json_error(403, "Forbidden", "CSRF Token 无效")
        return next_fn()


class InputValidationMiddleware(Middleware):
    """Input validation middleware (pass-through hook).

    Concrete schema validation is invoked on demand inside route handlers, so
    this middleware currently continues the chain regardless of the
    ``INPUT_VALIDATION_ENABLED`` setting.
    """

    def process(self, ctx: dict, next_fn: Callable) -> Optional[Response]:
        return next_fn()


class MiddlewareChain:
    """Ordered middleware pipeline.

    Default order: CORS -> auth -> CSRF -> input validation -> logging ->
    rate limiting.
    """

    def __init__(self):
        self.middlewares: list[Middleware] = []
        # CORS first, so its headers are recorded in ctx even when a later
        # middleware (e.g. auth) short-circuits with an error response.
        self.add(CorsMiddleware())
        self.add(AuthMiddleware())
        self.add(CSRFMiddleware())
        self.add(InputValidationMiddleware())
        self.add(LoggerMiddleware())
        # NOTE(review): the rate limiter runs after auth, so requests rejected
        # by AuthMiddleware (e.g. repeated API-key guesses) never reach it —
        # consider moving it earlier if that is not intended.
        self.add(RateLimitMiddleware())

    def add(self, middleware: Middleware):
        """Append ``middleware`` to the end of the chain."""
        self.middlewares.append(middleware)

    def run(self, ctx: dict[str, Any]) -> Optional[Response]:
        """Run the chain over ``ctx``.

        Returns the ``Response`` of the first middleware that short-circuits,
        or ``None`` when every middleware delegated to ``next_fn`` and the end
        of the chain was reached. CORS headers recorded in
        ``ctx["response_headers"]`` are also copied to ``ctx["_cors_headers"]``.
        """
        idx = 0

        def next_fn():
            nonlocal idx
            if idx < len(self.middlewares):
                mw = self.middlewares[idx]
                idx += 1
                return mw.process(ctx, next_fn)
            return None

        resp = next_fn()
        response_headers = ctx.get("response_headers")
        if response_headers:
            ctx["_cors_headers"] = response_headers
        return resp