- 核心功能从 store/ 迁移至 oss/core/ 框架层 - 实现 NBPF 包格式:多重签名(Ed25519+RSA-PSS+HMAC)+ 多重加密(AES-256-GCM) - 实现 NIR 编译器:基于 compile()+marshal 的跨平台中间表示 - 新增 nebula nbpf CLI 命令组(pack/unpack/verify/sign/keygen) - 新增 19 个 NBPF 测试用例,覆盖全链路 - 彻底重写 README,大型项目标准框架风格,所有图表使用 SVG - 更新 LICENSE 版权声明 - 清理旧版 store 插件目录(已迁移至 oss/core)
139 lines
4.3 KiB
Python
139 lines
4.3 KiB
Python
"""
限流工具 - 滑动窗口限流器
"""
|
|
import time
|
|
import threading
|
|
from typing import Dict, Callable, Optional
|
|
from collections import defaultdict, deque
|
|
|
|
from oss.config import get_config
|
|
from oss.core.http_api.server import Request, Response
|
|
|
|
|
|
class RateLimiter:
    """Per-identifier sliding-window rate limiter.

    Although historically labelled a "token bucket", this limiter keeps a log
    of accepted-request timestamps per identifier and allows a request only if
    fewer than ``max_requests`` were accepted within the last ``time_window``
    seconds. All mutable state is accessed under ``self.lock``.
    """

    def __init__(self, max_requests: int = 100, time_window: int = 60):
        """Create a limiter.

        Args:
            max_requests: maximum number of accepted requests per identifier
                within one window.
            time_window: window length in seconds.
        """
        self.max_requests = max_requests
        self.time_window = time_window
        # identifier -> deque of monotonic timestamps of accepted requests,
        # oldest first.
        self.requests: Dict[str, deque] = defaultdict(deque)
        self.lock = threading.Lock()
        # Monotonic time of the last stale-identifier sweep (see _sweep).
        self._last_sweep = time.monotonic()

    def _sweep(self, now: float) -> None:
        """Drop identifiers whose recorded requests have all expired.

        The caller must hold ``self.lock``. Without this sweep, every
        identifier ever seen would stay in ``self.requests`` forever. A
        removed identifier behaves exactly like one with an empty log.
        """
        cutoff = now - self.time_window
        stale = [
            key for key, times in self.requests.items()
            if not times or times[-1] <= cutoff
        ]
        for key in stale:
            del self.requests[key]
        self._last_sweep = now

    def is_allowed(self, identifier: str) -> bool:
        """Check whether a request from *identifier* is allowed, recording it if so.

        Returns:
            True if the request is under the limit. The request is then
            recorded. False if the limit is reached; nothing is recorded.
        """
        with self.lock:
            # Monotonic clock: wall-clock adjustments must not prematurely
            # expire recorded requests or lock a client out.
            now = time.monotonic()
            # Amortized cleanup, at most once per window.
            if now - self._last_sweep >= self.time_window:
                self._sweep(now)

            request_times = self.requests[identifier]

            # Discard requests that fell out of the window.
            cutoff = now - self.time_window
            while request_times and request_times[0] <= cutoff:
                request_times.popleft()

            if len(request_times) >= self.max_requests:
                return False

            request_times.append(now)
            return True
|
|
|
|
|
|
class RateLimitMiddleware:
    """Rate-limiting middleware that answers over-limit requests with HTTP 429.

    Each caller is identified by ``_get_client_identifier``. Requests are
    tracked in a sliding-window log per ``"<identifier>:<path>"`` key. A
    request is rejected once that key already holds ``max_requests``
    timestamps within the last ``time_window`` seconds.

    Paths starting with a prefix in ``endpoint_limits`` use that endpoint's
    limit. All other paths use ``global_limit``, which is read from config
    keys ``RATE_LIMIT_MAX_REQUESTS`` / ``RATE_LIMIT_TIME_WINDOW``.
    """

    def __init__(self):
        self.config = get_config()
        self.enabled = self.config.get("RATE_LIMIT_ENABLED", True)

        # Per-endpoint limits, matched by path prefix; the first match wins.
        self.endpoint_limits = {
            "/api/dashboard/stats": {
                "max_requests": 10,
                "time_window": 60
            },
        }

        # Default limit for paths without an endpoint-specific entry.
        self.global_limit = {
            "max_requests": self.config.get("RATE_LIMIT_MAX_REQUESTS", 100),
            "time_window": self.config.get("RATE_LIMIT_TIME_WINDOW", 60)
        }

        # limit key -> deque of monotonic timestamps of accepted requests,
        # oldest first.
        self.requests = {}
        self.lock = threading.Lock()
        # Monotonic time of the last stale-key sweep (see _sweep_stale_keys).
        self._last_sweep = time.monotonic()

    def _get_client_identifier(self, request: Request) -> str:
        """Return the key identifying the caller for rate limiting.

        A ``Bearer`` token yields ``"api_key:<token>"``. Otherwise the result
        is ``"ip:<addr>"``. The address is the first non-empty value of: the
        left-most ``X-Forwarded-For`` entry, ``X-Real-IP``, then
        ``Remote-Addr`` (default ``"unknown"``).
        """
        auth_header = request.headers.get("Authorization", "")
        if auth_header.startswith("Bearer "):
            # NOTE(review): the token is not validated here, so a client can
            # rotate arbitrary tokens to get a fresh budget on every request.
            # Consider falling back to IP keying for unverified tokens.
            return f"api_key:{auth_header[7:]}"

        # NOTE(review): these headers are client-controlled unless a trusted
        # reverse proxy overwrites them. Confirm the deployment does so.
        forwarded = request.headers.get("X-Forwarded-For") or ""
        # X-Forwarded-For may be "client, proxy1, proxy2"; keep only the client
        # so that the same caller maps to one key regardless of proxy path.
        ip = forwarded.split(",")[0].strip()
        if not ip:
            ip = (request.headers.get("X-Real-IP") or "").strip()
        if not ip:
            ip = request.headers.get("Remote-Addr", "unknown")

        return f"ip:{ip}"

    def _match_limit(self, path: str) -> dict:
        """Return the limit config for *path*.

        This is the first endpoint config whose prefix matches *path*, or
        ``self.global_limit`` when none matches (or the matched config is
        empty).
        """
        for endpoint, config in self.endpoint_limits.items():
            if path.startswith(endpoint):
                return config or self.global_limit
        return self.global_limit

    def _sweep_stale_keys(self, now: float) -> None:
        """Drop limit keys whose recorded requests have all expired.

        The caller must hold ``self.lock``. The longest configured window is
        used, so a key is removed only when none of its timestamps can still
        count against any limit. Removal is therefore equivalent to keeping
        an empty log. Without this, client-chosen identifiers and paths would
        accumulate in ``self.requests`` forever.
        """
        longest = max(
            [self.global_limit["time_window"]]
            + [cfg["time_window"] for cfg in self.endpoint_limits.values() if cfg]
        )
        cutoff = now - longest
        stale = [
            key for key, times in self.requests.items()
            if not times or times[-1] <= cutoff
        ]
        for key in stale:
            del self.requests[key]
        self._last_sweep = now

    def _is_rate_limited(self, identifier: str, path: str) -> bool:
        """Report whether this request exceeds its limit, recording it if not.

        Returns False without recording anything when rate limiting is
        disabled. Otherwise it prunes timestamps older than the applicable
        window. It then returns True (nothing recorded) if the key is already
        at ``max_requests``, or records the request and returns False.
        """
        if not self.enabled:
            return False

        limit = self._match_limit(path)
        max_requests = limit["max_requests"]
        time_window = limit["time_window"]
        # NOTE(review): the key includes the full path, so the global limit is
        # applied per path rather than per client. Confirm this is intended.
        limit_key = f"{identifier}:{path}"

        with self.lock:
            # Read the clock under the lock so timestamps are appended in
            # non-decreasing order, which left-side pruning relies on. Use a
            # monotonic clock so wall-clock changes cannot distort windows.
            now = time.monotonic()
            # Amortized cleanup, at most once per window.
            if now - self._last_sweep >= time_window:
                self._sweep_stale_keys(now)

            request_times = self.requests.setdefault(limit_key, deque())

            cutoff = now - time_window
            while request_times and request_times[0] <= cutoff:
                request_times.popleft()

            if len(request_times) >= max_requests:
                return True

            request_times.append(now)
            return False

    def _create_rate_limit_response(self, limit: Optional[dict] = None) -> Response:
        """Build the HTTP 429 response.

        Args:
            limit: the limit config that was exceeded. Defaults to
                ``self.global_limit``.
        """
        if limit is None:
            limit = self.global_limit
        window = str(limit["time_window"])
        return Response(
            status=429,
            headers={
                "Content-Type": "application/json",
                # Upper bound: every counted request expires within one window.
                "Retry-After": window,
                "X-Rate-Limit-Limit": str(limit["max_requests"]),
                "X-Rate-Limit-Window": window,
            },
            body='{"error": "Rate limit exceeded", "message": "请稍后再试"}'
        )

    def process(self, ctx: dict, next_fn: Callable) -> Optional[Response]:
        """Middleware entry point.

        Returns a 429 response for over-limit requests. Otherwise it returns
        ``next_fn()``. It passes straight through when disabled or when
        ``ctx`` has no truthy ``"request"``.
        """
        if not self.enabled:
            return next_fn()

        request = ctx.get("request")
        if not request:
            return next_fn()

        identifier = self._get_client_identifier(request)

        if self._is_rate_limited(identifier, request.path):
            # Report the limit that actually applied (endpoint or global).
            return self._create_rate_limit_response(self._match_limit(request.path))

        return next_fn()
|