"""中间件链 - CORS/鉴权/日志/限流/CSRF/输入验证等""" import json import time import threading from collections import deque from typing import Callable, Optional, Any from oss.config import get_config from oss.logger.logger import Log from .server import Request, Response from .rate_limiter import RateLimitMiddleware class Middleware: """中间件基类""" def process(self, ctx: dict[str, Any], next_fn: Callable) -> Optional[Response]: return next_fn() class CorsMiddleware(Middleware): """CORS 中间件""" def process(self, ctx: dict, next_fn: Callable) -> Optional[Response]: config = get_config() allowed_origins = config.get("CORS_ALLOWED_ORIGINS", ["http://localhost:3000", "http://127.0.0.1:3000"]) req = ctx.get("request") origin = req.headers.get("Origin", "") if req else "" if not allowed_origins or not origin: return next_fn() if origin in allowed_origins or "*" in allowed_origins: ctx["response_headers"] = { "Access-Control-Allow-Origin": origin, "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS", "Access-Control-Allow-Headers": "Content-Type, Authorization", "Access-Control-Allow-Credentials": "true", } return next_fn() class AuthMiddleware(Middleware): """鉴权中间件 - Bearer Token 认证""" @staticmethod def _get_public_paths() -> set: """获取公开路径白名单,优先从配置读取""" config = get_config() configured = config.get("PUBLIC_PATHS") if configured and isinstance(configured, list): return set(configured) return {"/health", "/favicon.ico", "/api/status", "/api/health"} def process(self, ctx: dict, next_fn: Callable) -> Optional[Response]: config = get_config() api_key = config.get("API_KEY") if not api_key: return next_fn() public_paths = self._get_public_paths() req = ctx.get("request") if req and req.path in public_paths: return next_fn() if req and req.method == "OPTIONS": return next_fn() auth_header = req.headers.get("Authorization", "") if req else "" token = auth_header.removeprefix("Bearer ").strip() if token != api_key or not token: Log.warn("Core", f"鉴权失败: {req.method} {req.path}" if req else "鉴权失败") return Response( status=401, body=json.dumps({"error": "Unauthorized", "message": "需要有效的 API Key"}), headers={"Content-Type": "application/json"}, ) return next_fn() class LoggerMiddleware(Middleware): """日志中间件""" _silent_paths = {"/api/dashboard/stats", "/favicon.ico", "/health"} def process(self, ctx: dict, next_fn: Callable) -> Optional[Response]: req = ctx.get("request") if req and req.path not in self._silent_paths: Log.info("Core", f"{req.method} {req.path}") return next_fn() class MiddlewareChain: """中间件链""" def __init__(self): self.middlewares: list[Middleware] = [] self.add(CorsMiddleware()) self.add(AuthMiddleware()) self.add(LoggerMiddleware()) self.add(RateLimitMiddleware()) def add(self, middleware: Middleware): self.middlewares.append(middleware) def run(self, ctx: dict[str, Any]) -> Optional[Response]: idx = 0 def next_fn(): nonlocal idx if idx < len(self.middlewares): mw = self.middlewares[idx] idx += 1 return mw.process(ctx, next_fn) return None resp = next_fn() response_headers = ctx.get("response_headers") if response_headers: ctx["_cors_headers"] = response_headers return resp