"""
Rate-limiting utilities: a sliding-window request limiter and an HTTP middleware.

Both classes implement a sliding-window log rather than a token bucket.
Each key keeps a deque of accepted-request timestamps. A request is
rejected once ``max_requests`` timestamps fall inside the last
``time_window`` seconds.
"""
import math
import time
import threading
from typing import Dict, Callable, Optional, Tuple
from collections import defaultdict, deque
from oss.config import get_config
from oss.core.http_api.server import Request, Response


def _drop_idle_keys(store: Dict[str, deque], cutoff: float) -> None:
    """Remove entries whose newest timestamp is at or before *cutoff*.

    Every timestamp in such an entry would be discarded on its next access
    anyway, so removing the entry cannot change any limit decision. Keeping
    it would only grow memory by one entry per distinct key ever seen.
    """
    idle = [key for key, times in store.items() if not times or times[-1] <= cutoff]
    for key in idle:
        del store[key]


class RateLimiter:
    """Sliding-window limiter keyed by an arbitrary identifier.

    Allows at most ``max_requests`` accepted calls to :meth:`is_allowed`
    per identifier within any ``time_window``-second span. All state is
    mutated while holding ``self.lock``.
    """

    def __init__(self, max_requests: int = 100, time_window: int = 60):
        self.max_requests = max_requests
        self.time_window = time_window
        # identifier -> time.monotonic() timestamps of accepted requests, oldest first
        self.requests: Dict[str, deque] = defaultdict(deque)
        self.lock = threading.Lock()
        # Monotonic time of the last sweep that dropped idle identifiers
        self._last_sweep = time.monotonic()

    def is_allowed(self, identifier: str) -> bool:
        """Check whether a request from *identifier* may proceed.

        Returns:
            True if fewer than ``max_requests`` requests were accepted in the
            last ``time_window`` seconds; the request is then recorded.
            False otherwise; nothing is recorded.
        """
        with self.lock:
            # Monotonic clock: unaffected by wall-clock adjustments (NTP, manual changes).
            now = time.monotonic()
            self._maybe_sweep(now)
            request_times = self.requests[identifier]

            # Discard timestamps that have left the window.
            cutoff = now - self.time_window
            while request_times and request_times[0] <= cutoff:
                request_times.popleft()

            if len(request_times) >= self.max_requests:
                return False

            request_times.append(now)
            return True

    def _maybe_sweep(self, now: float) -> None:
        """Drop identifiers idle for a full window, at most once per window.

        The caller must hold ``self.lock``.
        """
        # Floor the interval at 1s so a tiny window doesn't sweep on every call.
        if now - self._last_sweep < max(self.time_window, 1):
            return
        self._last_sweep = now
        _drop_idle_keys(self.requests, now - self.time_window)


class RateLimitMiddleware:
    """Rate-limiting middleware intended to mitigate DoS attacks.

    Each client (see :meth:`_get_client_identifier`) gets a sliding-window
    log per bucket:
    - Paths starting with a prefix in ``endpoint_limits`` share one bucket
      per prefix and use that prefix's limit.
    - Any other path gets its own bucket under ``global_limit``.
    """

    def __init__(self):
        self.config = get_config()
        # NOTE(review): if get_config() can return strings, "false" would be truthy here — confirm.
        self.enabled = self.config.get("RATE_LIMIT_ENABLED", True)

        # Endpoint-specific limits; keys are path prefixes matched with str.startswith.
        self.endpoint_limits = {
            "/api/dashboard/stats": {
                "max_requests": 10,
                "time_window": 60
            },
        }

        # Default limit for paths with no endpoint-specific entry.
        self.global_limit = {
            "max_requests": self.config.get("RATE_LIMIT_MAX_REQUESTS", 100),
            "time_window": self.config.get("RATE_LIMIT_TIME_WINDOW", 60)
        }

        # "<identifier>:<bucket>" -> time.monotonic() timestamps of accepted requests
        self.requests: Dict[str, deque] = {}
        self.lock = threading.Lock()
        # Monotonic time of the last sweep that dropped idle keys
        self._last_sweep = time.monotonic()

    def _get_client_identifier(self, request: Request) -> str:
        """Return the key used to attribute *request* to a client.

        Requests with a Bearer token are keyed as ``"api_key:<token>"``.
        All others are keyed as ``"ip:<addr>"``. The address comes from the
        first of X-Forwarded-For, X-Real-IP, or Remote-Addr that is present,
        falling back to ``"unknown"``.
        """
        auth_header = request.headers.get("Authorization", "")
        if auth_header.startswith("Bearer "):
            # NOTE(review): the token is not validated here, so a client can send
            # arbitrary tokens to obtain fresh buckets. Consider keying on IP until
            # the token has been authenticated.
            return f"api_key:{auth_header[7:]}"

        ip = request.headers.get("X-Forwarded-For", request.headers.get("X-Real-IP", ""))
        # X-Forwarded-For may be "client, proxy1, proxy2"; the left-most entry is
        # the originating client. Using the whole chain would split one client
        # across buckets whenever the proxy path changes.
        ip = ip.split(",")[0].strip()
        if not ip:
            ip = request.headers.get("Remote-Addr", "unknown")
        # NOTE(review): these headers are client-supplied unless a trusted proxy
        # overwrites them; a direct client can spoof them to evade the limit.
        return f"ip:{ip}"

    def _resolve_limit(self, path: str) -> Tuple[str, dict]:
        """Return ``(bucket, limit)`` governing *path*.

        An endpoint-specific limit uses its matched prefix as the bucket, so
        every path under that prefix shares one counter. Otherwise the path
        itself is the bucket and ``global_limit`` applies.
        """
        for endpoint, config in self.endpoint_limits.items():
            if path.startswith(endpoint):
                return endpoint, config
        return path, self.global_limit

    def _check(self, identifier: str, path: str) -> Tuple[bool, dict, float]:
        """Record a request and report whether it exceeds its limit.

        Returns:
            ``(limited, limit, retry_after)``.
            - ``limited`` is True when the request must be rejected; it is
              then not recorded.
            - ``limit`` is the config that applied.
            - ``retry_after`` is the number of seconds until a slot frees
              (0.0 when not limited).
        """
        bucket, limit = self._resolve_limit(path)
        if not self.enabled:
            return False, limit, 0.0

        max_requests = limit["max_requests"]
        time_window = limit["time_window"]
        limit_key = f"{identifier}:{bucket}"

        with self.lock:
            # Read the clock under the lock so each deque stays in time order.
            now = time.monotonic()
            self._maybe_sweep(now)

            request_times = self.requests.get(limit_key)
            if request_times is None:
                request_times = self.requests[limit_key] = deque()

            # Discard timestamps that have left the window.
            cutoff = now - time_window
            while request_times and request_times[0] <= cutoff:
                request_times.popleft()

            if len(request_times) >= max_requests:
                if max_requests > 0 and request_times:
                    # A slot frees once enough of the oldest entries expire.
                    oldest_blocking = request_times[len(request_times) - max_requests]
                    retry_after = oldest_blocking + time_window - now
                else:
                    retry_after = float(time_window)
                return True, limit, retry_after

            request_times.append(now)
            return False, limit, 0.0

    def _maybe_sweep(self, now: float) -> None:
        """Drop keys idle for the longest configured window, at most once per that window.

        Using the longest window is safe for every bucket: a key idle that
        long has nothing left in any window. The caller must hold ``self.lock``.
        """
        windows = [cfg["time_window"] for cfg in self.endpoint_limits.values()]
        windows.append(self.global_limit["time_window"])
        horizon = max(windows)
        # Floor the interval at 1s so a tiny window doesn't sweep on every request.
        if now - self._last_sweep < max(horizon, 1):
            return
        self._last_sweep = now
        _drop_idle_keys(self.requests, now - horizon)

    def _is_rate_limited(self, identifier: str, path: str) -> bool:
        """Return True if this request exceeds its limit; otherwise record it and return False.

        Always returns False when rate limiting is disabled.
        """
        return self._check(identifier, path)[0]

    def _create_rate_limit_response(self, limit: Optional[dict] = None,
                                    retry_after: Optional[float] = None) -> Response:
        """Build the HTTP 429 response.

        Args:
            limit: The limit config that was exceeded. Defaults to
                ``global_limit``.
            retry_after: Seconds until a request slot frees up. Defaults to
                the limit's full window. Rounded up to whole seconds (minimum
                1), because Retry-After carries an integer delay.
        """
        if limit is None:
            limit = self.global_limit
        if retry_after is None:
            retry_after = limit["time_window"]
        retry_seconds = max(1, math.ceil(retry_after))
        return Response(
            status=429,
            headers={
                "Content-Type": "application/json",
                "Retry-After": str(retry_seconds),
                "X-Rate-Limit-Limit": str(limit["max_requests"]),
                "X-Rate-Limit-Window": str(limit["time_window"]),
            },
            body='{"error": "Rate limit exceeded", "message": "请稍后再试"}'
        )

    def process(self, ctx: dict, next_fn: Callable) -> Optional[Response]:
        """Apply rate limiting, then continue the middleware chain.

        Args:
            ctx: Middleware context; the request is read from ``ctx["request"]``.
            next_fn: Continuation invoked when the request is allowed.

        Returns:
            A 429 response when the request is limited; otherwise whatever
            ``next_fn()`` returns.
        """
        if not self.enabled:
            return next_fn()

        request = ctx.get("request")
        if not request:
            return next_fn()

        identifier = self._get_client_identifier(request)
        limited, limit, retry_after = self._check(identifier, request.path)
        if limited:
            # Report the limit that actually triggered, not always the global one.
            return self._create_rate_limit_response(limit, retry_after)

        return next_fn()