Files
NebulaShell/oss/core/http_api/middleware.py
Starlight-apk 5e957096fa
Some checks failed
CI / test (3.10) (push) Has been cancelled
CI / test (3.11) (push) Has been cancelled
CI / test (3.12) (push) Has been cancelled
CI / test (3.13) (push) Has been cancelled
feat: Phase 1 - 安全中间件 + 运维工具箱
新增 oss/core/security/ 模块(852行):
- jwt_auth.py: JWT签发/验证(HMAC-SHA256,零外部依赖)
- csrf.py: CSRF Token生成与校验
- input_validator.py: JSON Schema校验+类型强制
- tls.py: 自签名证书生成+SSL上下文

新增 oss/core/ops/ 模块:
- health.py: 增强版/health端点(CPU/内存/磁盘/运行时间)
- metrics.py: Prometheus兼容/metrics端点

对接改造:
- engine.py: 导出新模块
- manager.py: 注册/api/login /health /metrics路由
- middleware.py: CSRF+InputValidation中间件
- config.py: JWT_SECRET/CSRF_SECRET等配置项
- security.py→security/__init__.py: 合并插件沙箱与HTTP安全
2026-05-17 15:42:40 +08:00

183 lines
6.2 KiB
Python

"""中间件链 - CORS/鉴权/日志/限流/CSRF/输入验证等"""
import hmac
import json
import threading
import time
from collections import deque
from typing import Callable, Optional, Any

from oss.config import get_config
from oss.logger.logger import Log
from .server import Request, Response
from .rate_limiter import RateLimitMiddleware
class Middleware:
    """Base class for every middleware placed in a MiddlewareChain.

    Subclasses override :meth:`process`. The default implementation does no
    work of its own and simply delegates to the rest of the chain.
    """

    def process(self, ctx: dict[str, Any], next_fn: Callable) -> Optional[Response]:
        """Handle one request.

        Args:
            ctx: Per-request context dict; the same dict is handed to every
                middleware in the chain.
            next_fn: Zero-argument continuation that runs the remaining
                middlewares.

        Returns:
            A Response to short-circuit the chain, otherwise whatever
            ``next_fn`` produced (None when the end of the chain is reached).
        """
        downstream_result = next_fn()
        return downstream_result
class CorsMiddleware(Middleware):
    """CORS middleware.

    The CORS response headers are stored in ``ctx["response_headers"]`` for a
    later stage to attach to the response. This happens when the request has
    an ``Origin`` header and one of these holds:

    * the origin is listed in the ``CORS_ALLOWED_ORIGINS`` setting;
    * the setting contains ``"*"``.

    The request always continues down the chain; this middleware never
    rejects anything itself.
    """

    _DEFAULT_ORIGINS = ("http://localhost:3000", "http://127.0.0.1:3000")
    # X-CSRF-Token / X-Session-Id are the headers CSRFMiddleware requires on
    # state-changing requests. They must be allowed here, or the browser
    # rejects the cross-origin preflight before the real request is sent.
    _ALLOWED_HEADERS = "Content-Type, Authorization, X-CSRF-Token, X-Session-Id"

    @staticmethod
    def _normalize_origins(value: Any) -> Any:
        """Turn a comma-separated origin string into a list of origins.

        A plain string would otherwise be matched with substring semantics by
        the ``in`` operator. For example, ``"http://localhost:300"`` would
        match ``"http://localhost:3000"``. Non-string values are returned
        unchanged.
        """
        if isinstance(value, str):
            return [part.strip() for part in value.split(",") if part.strip()]
        return value

    def process(self, ctx: dict, next_fn: Callable) -> Optional[Response]:
        config = get_config()
        allowed_origins = self._normalize_origins(
            config.get("CORS_ALLOWED_ORIGINS", self._DEFAULT_ORIGINS)
        )
        req = ctx.get("request")
        origin = req.headers.get("Origin", "") if req else ""
        if not allowed_origins or not origin:
            return next_fn()
        if origin in allowed_origins or "*" in allowed_origins:
            # The concrete origin is echoed back (never "*") because browsers
            # reject a wildcard Allow-Origin on credentialed requests.
            # NOTE(review): with a "*" entry this reflects *any* origin
            # together with Allow-Credentials: true. Confirm that is
            # acceptable wherever "*" is configured.
            ctx["response_headers"] = {
                "Access-Control-Allow-Origin": origin,
                "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
                "Access-Control-Allow-Headers": self._ALLOWED_HEADERS,
                "Access-Control-Allow-Credentials": "true",
                # The headers differ per Origin. Without Vary, a shared cache
                # could serve one origin's CORS headers to another.
                "Vary": "Origin",
            }
        return next_fn()
class AuthMiddleware(Middleware):
    """Authentication middleware with two modes: API_KEY and JWT.

    * Public paths and CORS preflight (``OPTIONS``) requests always pass.
    * API_KEY mode (``API_KEY`` is configured): the ``Authorization`` header
      must carry that key, with or without a ``Bearer`` prefix. Anything else
      gets a JSON 401.
    * JWT mode (no ``API_KEY``): a supplied token is checked with
      ``verify_token``. A truthy payload is stored in ``ctx["user"]``; an
      invalid token gets a JSON 401.
    """

    @staticmethod
    def _get_public_paths() -> set:
        """Return the public-path whitelist, preferring the PUBLIC_PATHS setting."""
        config = get_config()
        configured = config.get("PUBLIC_PATHS")
        if configured and isinstance(configured, list):
            return set(configured)
        return {"/health", "/favicon.ico", "/api/status", "/api/health", "/api/login", "/metrics"}

    @staticmethod
    def _extract_token(req: Any) -> str:
        """Return the credential carried in the ``Authorization`` header.

        The ``Bearer`` scheme is stripped case-insensitively, since auth
        schemes are case-insensitive per RFC 7235. A header without that
        scheme is returned whole (stripped), so a raw API key still works.
        Returns "" when there is no request or no header.
        """
        auth_header = req.headers.get("Authorization", "") if req else ""
        scheme, sep, credentials = auth_header.partition(" ")
        if sep and scheme.lower() == "bearer":
            return credentials.strip()
        return auth_header.strip()

    @staticmethod
    def _unauthorized(req: Any, message: str) -> Response:
        """Log the failed attempt and build a JSON 401 response carrying ``message``."""
        Log.warn("Core", f"鉴权失败: {req.method} {req.path}" if req else "鉴权失败")
        return Response(
            status=401,
            body=json.dumps({"error": "Unauthorized", "message": message}),
            headers={"Content-Type": "application/json"},
        )

    def process(self, ctx: dict, next_fn: Callable) -> Optional[Response]:
        config = get_config()
        api_key = config.get("API_KEY", "")
        public_paths = self._get_public_paths()
        req = ctx.get("request")
        if req and req.path in public_paths:
            return next_fn()
        # CORS preflight requests carry no credentials and must not be rejected.
        if req and req.method == "OPTIONS":
            return next_fn()

        token = self._extract_token(req)

        if not api_key:
            # JWT mode.
            if not token:
                # NOTE(review): with no API_KEY configured, a request carrying
                # no token at all passes through unauthenticated. Confirm this
                # open mode is intended.
                return next_fn()
            from oss.core.security.jwt_auth import verify_token
            payload = verify_token(token)
            if not payload:
                return self._unauthorized(req, "Token 无效或已过期")
            ctx["user"] = payload
            return next_fn()

        # API_KEY mode. compare_digest avoids leaking, via response timing,
        # how much of the key a guess got right. The values are compared as
        # bytes because compare_digest rejects non-ASCII str.
        expected = str(api_key).encode("utf-8")
        if token and hmac.compare_digest(token.encode("utf-8"), expected):
            return next_fn()
        return self._unauthorized(req, "需要有效的认证凭据")
class LoggerMiddleware(Middleware):
    """Request-logging middleware.

    Emits ``"<METHOD> <path>"`` at info level for each request whose path is
    not in ``_silent_paths``, then forwards the request unchanged.
    """

    # Paths excluded from request logging (presumably high-frequency
    # endpoints whose entries would only add noise).
    _silent_paths = {"/api/dashboard/stats", "/favicon.ico", "/health"}

    def process(self, ctx: dict, next_fn: Callable) -> Optional[Response]:
        request = ctx.get("request")
        should_log = bool(request) and request.path not in self._silent_paths
        if should_log:
            Log.info("Core", f"{request.method} {request.path}")
        return next_fn()
class CSRFProtectionMiddlewareDoc:  # pragma: no cover - placeholder removed below
    pass


del CSRFProtectionMiddlewareDoc


class CSRFMiddleware(Middleware):
    """CSRF protection middleware.

    When ``CSRF_ENABLED`` is false, the whole check is skipped. Otherwise:

    * Requests whose method ``CSRFProtection.is_safe_method`` accepts pass
      straight through.
    * Every other request must send ``X-CSRF-Token`` and ``X-Session-Id``
      headers, and ``CSRFProtection().verify_token`` must accept the pair.
    * Any failure returns a JSON 403.
    """

    @staticmethod
    def _forbidden(message: str) -> Response:
        """Build a JSON 403 response carrying ``message``."""
        return Response(
            status=403,
            body=json.dumps({"error": "Forbidden", "message": message}),
            headers={"Content-Type": "application/json"},
        )

    def process(self, ctx: dict, next_fn: Callable) -> Optional[Response]:
        config = get_config()
        if not config.get("CSRF_ENABLED", True):
            return next_fn()
        # The import must precede the first use of CSRFProtection. A
        # function-level import binds a *local* name for the whole function,
        # so referencing it before the import raises UnboundLocalError.
        from oss.core.security.csrf import CSRFProtection

        req = ctx.get("request")
        if not req or CSRFProtection.is_safe_method(req.method):
            return next_fn()
        # NOTE(review): there is no public-path exemption here, so unsafe
        # requests to whitelisted paths (e.g. a POST to a login endpoint)
        # also need a CSRF token. Confirm clients can obtain one beforehand.
        # Both values are read from request headers only (not the body).
        token = req.headers.get("X-CSRF-Token", "")
        session_id = req.headers.get("X-Session-Id", "")
        if not token or not session_id:
            return self._forbidden("缺少 CSRF Token")
        if not CSRFProtection().verify_token(session_id, token):
            return self._forbidden("CSRF Token 无效")
        return next_fn()
class InputValidationMiddleware(Middleware):
    """Input-validation middleware (currently a pass-through).

    Concrete schema validation is invoked on demand inside route handlers.
    As a result, this middleware forwards every request unchanged, whatever
    the value of ``INPUT_VALIDATION_ENABLED``.
    """

    def process(self, ctx: dict, next_fn: Callable) -> Optional[Response]:
        # The flag is read so the setting stays wired into the chain. The
        # enabled and disabled paths currently behave identically.
        _validation_enabled = get_config().get("INPUT_VALIDATION_ENABLED", True)
        return next_fn()
class MiddlewareChain:
    """Ordered pipeline of middlewares run once per request.

    Default execution order: CORS -> auth -> CSRF -> input validation ->
    request logging -> rate limiting. Each middleware receives the shared
    ``ctx`` dict plus a ``next_fn`` continuation that runs the remainder of
    the pipeline.
    """

    def __init__(self):
        self.middlewares: list[Middleware] = []
        # NOTE(review): RateLimitMiddleware and LoggerMiddleware sit after
        # AuthMiddleware and CSRFMiddleware. Requests those two reject with
        # 401/403 are therefore never passed to the rate limiter or the
        # request logger. Confirm this ordering is intended.
        for factory in (
            CorsMiddleware,
            AuthMiddleware,
            CSRFMiddleware,
            InputValidationMiddleware,
            LoggerMiddleware,
            RateLimitMiddleware,
        ):
            self.add(factory())

    def add(self, middleware: Middleware):
        """Append ``middleware`` to the end of the pipeline."""
        self.middlewares.append(middleware)

    def run(self, ctx: dict[str, Any]) -> Optional[Response]:
        """Run the pipeline for one request context.

        Returns the first Response produced by a middleware, or None when every
        middleware forwarded and the end of the pipeline was reached. Headers
        that a middleware left in ``ctx["response_headers"]`` are copied to
        ``ctx["_cors_headers"]``.
        """
        position = 0

        def next_fn():
            nonlocal position
            if position >= len(self.middlewares):
                return None
            current = self.middlewares[position]
            position += 1
            return current.process(ctx, next_fn)

        result = next_fn()
        cors_headers = ctx.get("response_headers")
        if cors_headers:
            ctx["_cors_headers"] = cors_headers
        return result