feat: Phase 1 - 安全中间件 + 运维工具箱
新增 oss/core/security/ 模块(852行): - jwt_auth.py: JWT签发/验证(HMAC-SHA256,零外部依赖) - csrf.py: CSRF Token生成与校验 - input_validator.py: JSON Schema校验+类型强制 - tls.py: 自签名证书生成+SSL上下文 新增 oss/core/ops/ 模块: - health.py: 增强版/health端点(CPU/内存/磁盘/运行时间) - metrics.py: Prometheus兼容/metrics端点 对接改造: - engine.py: 导出新模块 - manager.py: 注册/api/login /health /metrics路由 - middleware.py: CSRF+InputValidation中间件 - config.py: JWT_SECRET/CSRF_SECRET等配置项 - security.py→security/__init__.py: 合并插件沙箱与HTTP安全
This commit is contained in:
272
oss/core/security/__init__.py
Normal file
272
oss/core/security/__init__.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""安全模块 — 插件沙箱 + HTTP 安全中间件
|
||||
|
||||
插件沙箱:PluginProxy, IntegrityChecker, MemoryGuard, AuditLogger, TamperMonitor, FallbackManager
|
||||
HTTP 安全:JWT, CSRF, InputValidator, TLS
|
||||
"""
|
||||
# ── 插件沙箱(来自 oss/core/security.py) ──
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import hashlib
|
||||
import time
|
||||
import json
|
||||
import functools
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Callable, TYPE_CHECKING
|
||||
from collections import deque
|
||||
|
||||
from oss.logger.logger import Log
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from oss.core.manager import PluginManager
|
||||
|
||||
|
||||
class PluginPermissionError(Exception):
|
||||
"""插件权限错误"""
|
||||
pass
|
||||
|
||||
|
||||
class PluginProxy:
|
||||
"""插件代理 - 防止越级访问"""
|
||||
def __init__(self, plugin_name: str, plugin_instance: Any, allowed_plugins: list[str], all_plugins: dict):
|
||||
self._plugin_name = plugin_name
|
||||
self._plugin_instance = plugin_instance
|
||||
self._allowed_plugins = set(allowed_plugins)
|
||||
self._all_plugins = all_plugins
|
||||
|
||||
def get_plugin(self, name: str) -> Any:
|
||||
if name not in self._allowed_plugins and "*" not in self._allowed_plugins:
|
||||
raise PluginPermissionError(f"插件 '{self._plugin_name}' 无权访问插件 '{name}'")
|
||||
if name not in self._all_plugins:
|
||||
return None
|
||||
return self._all_plugins[name]["instance"]
|
||||
|
||||
def list_plugins(self) -> list[str]:
|
||||
if "*" in self._allowed_plugins:
|
||||
return list(self._all_plugins.keys())
|
||||
return [n for n in self._allowed_plugins if n in self._all_plugins]
|
||||
|
||||
def get_capability(self, capability: str) -> Any:
|
||||
return None
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
return getattr(self._plugin_instance, name)
|
||||
|
||||
|
||||
class IntegrityChecker:
|
||||
"""文件完整性检查"""
|
||||
def __init__(self):
|
||||
self._hashes: dict[str, str] = {}
|
||||
|
||||
def compute_hash(self, plugin_dir: Path) -> str:
|
||||
hasher = hashlib.sha256()
|
||||
for file_path in sorted(plugin_dir.rglob("*")):
|
||||
if file_path.is_file() and "__pycache__" not in file_path.parts and file_path.name != "SIGNATURE":
|
||||
rel_path = str(file_path.relative_to(plugin_dir))
|
||||
hasher.update(rel_path.encode("utf-8"))
|
||||
hasher.update(file_path.read_bytes())
|
||||
return hasher.hexdigest()
|
||||
|
||||
def register(self, plugin_name: str, plugin_dir: Path):
|
||||
self._hashes[plugin_name] = self.compute_hash(plugin_dir)
|
||||
|
||||
def verify(self, plugin_name: str, plugin_dir: Path) -> tuple[bool, str]:
|
||||
if plugin_name not in self._hashes:
|
||||
return False, f"插件 '{plugin_name}' 未注册完整性检查"
|
||||
current = self.compute_hash(plugin_dir)
|
||||
if current == self._hashes[plugin_name]:
|
||||
return True, "完整性验证通过"
|
||||
return False, f"文件 hash 不匹配,插件可能被篡改"
|
||||
|
||||
def get_hash(self, plugin_name: str) -> Optional[str]:
|
||||
return self._hashes.get(plugin_name)
|
||||
|
||||
|
||||
class MemoryGuard:
|
||||
"""运行时内存保护 - 防止插件修改 Core 内部状态"""
|
||||
FROZEN_ATTRS = {
|
||||
"plugins", "capability_registry", "lifecycle_manager",
|
||||
"dependency_resolver", "signature_verifier", "pl_injector",
|
||||
"integrity_checker", "audit_logger", "tamper_monitor",
|
||||
"fallback_manager", "http_server", "repl_shell",
|
||||
}
|
||||
|
||||
def __init__(self, manager: PluginManager):
|
||||
self._manager = manager
|
||||
self._protected = True
|
||||
|
||||
def enable(self):
|
||||
self._protected = True
|
||||
|
||||
def disable(self):
|
||||
self._protected = False
|
||||
|
||||
def check_setattr(self, obj: Any, name: str, value: Any) -> bool:
|
||||
if not self._protected:
|
||||
return True
|
||||
if obj is self._manager and name in self.FROZEN_ATTRS:
|
||||
Log.warn("Core", f"内存防护: 阻止了对 Core 内部属性 '{name}' 的修改")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class AuditLogger:
|
||||
"""插件行为审计"""
|
||||
def __init__(self, max_logs: int = 1000):
|
||||
self._logs: deque = deque(maxlen=max_logs)
|
||||
self._enabled = True
|
||||
|
||||
def enable(self):
|
||||
self._enabled = True
|
||||
|
||||
def disable(self):
|
||||
self._enabled = False
|
||||
|
||||
def log(self, plugin_name: str, action: str, detail: str = ""):
|
||||
if not self._enabled:
|
||||
return
|
||||
self._logs.append({"time": time.time(), "plugin": plugin_name, "action": action, "detail": detail})
|
||||
|
||||
def get_logs(self, plugin_name: str = None, limit: int = 50) -> list[dict]:
|
||||
if plugin_name:
|
||||
filtered = [log for log in self._logs if log["plugin"] == plugin_name]
|
||||
else:
|
||||
filtered = list(self._logs)
|
||||
return filtered[-limit:]
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
stats: dict[str, int] = {}
|
||||
for log in self._logs:
|
||||
stats[log["plugin"]] = stats.get(log["plugin"], 0) + 1
|
||||
return stats
|
||||
|
||||
|
||||
class TamperMonitor:
|
||||
"""防篡改监控"""
|
||||
def __init__(self, manager: PluginManager, interval: int = 30):
|
||||
self._manager = manager
|
||||
self._interval = interval
|
||||
self._running = False
|
||||
self._thread = None
|
||||
self._alerts: deque = deque(maxlen=100)
|
||||
|
||||
def start(self):
|
||||
self._running = True
|
||||
self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
|
||||
self._thread.start()
|
||||
Log.info("Core", f"防篡改监控已启动 (间隔: {self._interval}s)")
|
||||
|
||||
def stop(self):
|
||||
self._running = False
|
||||
if self._thread:
|
||||
self._thread.join(timeout=5)
|
||||
|
||||
def _monitor_loop(self):
|
||||
while self._running:
|
||||
try:
|
||||
for plugin_name, info in self._manager.plugins.items():
|
||||
plugin_dir = self._manager._get_plugin_dir(plugin_name)
|
||||
if not plugin_dir:
|
||||
continue
|
||||
valid, msg = self._manager.integrity_checker.verify(plugin_name, plugin_dir)
|
||||
if not valid:
|
||||
alert = {"time": time.time(), "plugin": plugin_name, "message": msg}
|
||||
self._alerts.append(alert)
|
||||
Log.error("Core", f"防篡改告警: 插件 '{plugin_name}' 可能被篡改!")
|
||||
try:
|
||||
info["instance"].stop()
|
||||
lifecycle = self._manager.lifecycle_manager.get(plugin_name)
|
||||
if lifecycle:
|
||||
lifecycle.mark_crashed()
|
||||
except Exception as e:
|
||||
Log.error("Core", f"停止被篡改插件 '{plugin_name}' 失败: {e}")
|
||||
except Exception as e:
|
||||
Log.error("Core", f"防篡改监控异常: {e}")
|
||||
time.sleep(self._interval)
|
||||
|
||||
def get_alerts(self) -> list[dict]:
|
||||
return list(self._alerts)
|
||||
|
||||
|
||||
class FallbackManager:
|
||||
"""降级恢复机制"""
|
||||
def __init__(self, manager: PluginManager, max_retries: int = 3):
|
||||
self._manager = manager
|
||||
self._max_retries = max_retries
|
||||
self._retry_counts: dict[str, int] = {}
|
||||
self._degraded: set[str] = set()
|
||||
|
||||
def wrap_plugin_method(self, plugin_name: str, method: Callable) -> Callable:
|
||||
@functools.wraps(method)
|
||||
def safe_method(*args, **kwargs):
|
||||
try:
|
||||
return method(*args, **kwargs)
|
||||
except Exception as e:
|
||||
Log.error("Core", f"插件 '{plugin_name}' 方法 '{method.__name__}' 异常: {e}")
|
||||
self._handle_crash(plugin_name)
|
||||
return None
|
||||
return safe_method
|
||||
|
||||
def _handle_crash(self, plugin_name: str):
|
||||
retry_count = self._retry_counts.get(plugin_name, 0)
|
||||
lifecycle = self._manager.lifecycle_manager.get(plugin_name)
|
||||
bridge = self._manager._get_bridge()
|
||||
if bridge and plugin_name != "plugin-bridge":
|
||||
bridge.emit("plugin.crashed", name=plugin_name, retry=retry_count)
|
||||
if retry_count < self._max_retries:
|
||||
self._retry_counts[plugin_name] = retry_count + 1
|
||||
Log.warn("Core", f"插件 '{plugin_name}' 崩溃,正在重启 (第 {retry_count + 1}/{self._max_retries} 次)")
|
||||
try:
|
||||
if lifecycle:
|
||||
lifecycle.mark_crashed()
|
||||
self._manager._restart_plugin(plugin_name)
|
||||
if lifecycle:
|
||||
lifecycle.start()
|
||||
Log.ok("Core", f"插件 '{plugin_name}' 重启成功")
|
||||
except Exception as e:
|
||||
Log.error("Core", f"插件 '{plugin_name}' 重启失败: {e}")
|
||||
else:
|
||||
Log.error("Core", f"插件 '{plugin_name}' 超过最大重试次数,标记为降级")
|
||||
self._degraded.add(plugin_name)
|
||||
if lifecycle:
|
||||
lifecycle.mark_degraded()
|
||||
|
||||
def recover(self, plugin_name: str) -> bool:
|
||||
if plugin_name not in self._degraded:
|
||||
return False
|
||||
self._retry_counts[plugin_name] = 0
|
||||
self._degraded.discard(plugin_name)
|
||||
try:
|
||||
self._manager._restart_plugin(plugin_name)
|
||||
lifecycle = self._manager.lifecycle_manager.get(plugin_name)
|
||||
if lifecycle:
|
||||
lifecycle.start()
|
||||
Log.ok("Core", f"插件 '{plugin_name}' 已手动恢复")
|
||||
return True
|
||||
except Exception as e:
|
||||
Log.error("Core", f"恢复插件 '{plugin_name}' 失败: {e}")
|
||||
return False
|
||||
|
||||
def is_degraded(self, plugin_name: str) -> bool:
|
||||
return plugin_name in self._degraded
|
||||
|
||||
def get_degraded_plugins(self) -> list[str]:
|
||||
return list(self._degraded)
|
||||
|
||||
|
||||
# ── HTTP 安全中间件 ──
|
||||
from .jwt_auth import JWTAuth, JWTError, issue_token, verify_token, get_jwt_auth
|
||||
from .csrf import CSRFProtection
|
||||
from .input_validator import InputValidator, ValidationError, get_validator
|
||||
from .tls import TLSManager
|
||||
|
||||
__all__ = [
|
||||
# 插件沙箱
|
||||
"PluginPermissionError", "PluginProxy", "IntegrityChecker",
|
||||
"MemoryGuard", "AuditLogger", "TamperMonitor", "FallbackManager",
|
||||
# HTTP 安全
|
||||
"JWTAuth", "JWTError", "issue_token", "verify_token", "get_jwt_auth",
|
||||
"CSRFProtection",
|
||||
"InputValidator", "ValidationError", "get_validator",
|
||||
"TLSManager",
|
||||
]
|
||||
50
oss/core/security/csrf.py
Normal file
50
oss/core/security/csrf.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""CSRF 防护 — Token 校验中间件"""
|
||||
import secrets
|
||||
import time
|
||||
import hashlib
|
||||
from typing import Optional
|
||||
|
||||
from oss.config import get_config
|
||||
from oss.logger.logger import Log
|
||||
|
||||
|
||||
class CSRFProtection:
|
||||
"""CSRF Token 生成与验证"""
|
||||
|
||||
def __init__(self, secret: str = None):
|
||||
config = get_config()
|
||||
self._secret = secret or config.get("CSRF_SECRET", "")
|
||||
if not self._secret:
|
||||
self._secret = hashlib.sha256(config.get("API_KEY", "nebula-csrf-default").encode()).hexdigest()
|
||||
self._token_ttl = config.get("CSRF_TOKEN_TTL", 3600) # 默认1小时
|
||||
|
||||
def generate_token(self, session_id: str) -> str:
|
||||
"""生成 CSRF Token(绑定 session)"""
|
||||
salt = secrets.token_hex(16)
|
||||
timestamp = int(time.time())
|
||||
raw = f"{session_id}:{salt}:{timestamp}:{self._secret}"
|
||||
token = hashlib.sha256(raw.encode()).hexdigest()
|
||||
return f"{timestamp}:{salt}:{token}"
|
||||
|
||||
def verify_token(self, session_id: str, token: str) -> bool:
|
||||
"""验证 CSRF Token"""
|
||||
try:
|
||||
parts = token.split(":")
|
||||
if len(parts) != 3:
|
||||
return False
|
||||
timestamp, salt, hash_val = parts
|
||||
|
||||
# 检查过期
|
||||
if int(time.time()) - int(timestamp) > self._token_ttl:
|
||||
return False
|
||||
|
||||
expected = hashlib.sha256(f"{session_id}:{salt}:{timestamp}:{self._secret}".encode()).hexdigest()
|
||||
return hash_val == expected
|
||||
except (ValueError, IndexError):
|
||||
return False
|
||||
|
||||
SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}
|
||||
|
||||
@staticmethod
|
||||
def is_safe_method(method: str) -> bool:
|
||||
return method.upper() in CSRFProtection.SAFE_METHODS
|
||||
158
oss/core/security/input_validator.py
Normal file
158
oss/core/security/input_validator.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""输入验证 — JSON Schema 校验 + 参数白名单 + 类型强制"""
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Optional
|
||||
|
||||
from oss.logger.logger import Log
|
||||
|
||||
|
||||
class ValidationError(Exception):
|
||||
def __init__(self, message: str, field: str = None):
|
||||
self.field = field
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InputValidator:
|
||||
"""输入验证器"""
|
||||
|
||||
# ── 内置类型校验器 ──
|
||||
|
||||
@staticmethod
|
||||
def is_string(val: Any, min_len: int = 0, max_len: int = None) -> bool:
|
||||
if not isinstance(val, str):
|
||||
return False
|
||||
if len(val) < min_len:
|
||||
return False
|
||||
if max_len and len(val) > max_len:
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def is_integer(val: Any, min_val: int = None, max_val: int = None) -> bool:
|
||||
if not isinstance(val, int) or isinstance(val, bool):
|
||||
return False
|
||||
if min_val is not None and val < min_val:
|
||||
return False
|
||||
if max_val is not None and val > max_val:
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def is_float(val: Any, min_val: float = None, max_val: float = None) -> bool:
|
||||
if not isinstance(val, (int, float)) or isinstance(val, bool):
|
||||
return False
|
||||
if min_val is not None and val < min_val:
|
||||
return False
|
||||
if max_val is not None and val > max_val:
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def is_boolean(val: Any) -> bool:
|
||||
return isinstance(val, bool)
|
||||
|
||||
@staticmethod
|
||||
def is_list(val: Any, item_type: type = None) -> bool:
|
||||
if not isinstance(val, list):
|
||||
return False
|
||||
if item_type:
|
||||
return all(isinstance(v, item_type) for v in val)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def is_dict(val: Any) -> bool:
|
||||
return isinstance(val, dict)
|
||||
|
||||
@staticmethod
|
||||
def is_email(val: str) -> bool:
|
||||
return bool(re.match(r"^[^\s@]+@[^\s@]+\.[^\s@]+$", val))
|
||||
|
||||
@staticmethod
|
||||
def is_ip_address(val: str) -> bool:
|
||||
return bool(re.match(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", val))
|
||||
|
||||
@staticmethod
|
||||
def is_alphanumeric(val: str) -> bool:
|
||||
return bool(re.match(r"^[a-zA-Z0-9_\-]+$", val))
|
||||
|
||||
# ── JSON Schema 校验 ──
|
||||
|
||||
def validate_schema(self, data: dict, schema: dict) -> list[str]:
|
||||
"""根据 JSON Schema 校验数据,返回错误列表
|
||||
|
||||
Schema 格式:
|
||||
{
|
||||
"field_name": {
|
||||
"type": "string|int|float|bool|list|dict|email|ip",
|
||||
"required": True/False,
|
||||
"min_len": 1, # 仅 string
|
||||
"max_len": 100, # 仅 string
|
||||
"min_val": 0, # 仅 int/float
|
||||
"max_val": 100, # 仅 int/float
|
||||
"pattern": "^...$", # 正则(仅 string)
|
||||
"default": "val", # 默认值(可选字段)
|
||||
"items": "string", # list 元素类型
|
||||
"fields": {...}, # 嵌套 dict schema
|
||||
}
|
||||
}
|
||||
"""
|
||||
errors = []
|
||||
type_map = {
|
||||
"string": lambda v, r: self.is_string(v, r.get("min_len", 0), r.get("max_len")),
|
||||
"int": lambda v, r: self.is_integer(v, r.get("min_val"), r.get("max_val")),
|
||||
"float": lambda v, r: self.is_float(v, r.get("min_val"), r.get("max_val")),
|
||||
"bool": lambda v, _: self.is_boolean(v),
|
||||
"list": lambda v, r: self.is_list(v),
|
||||
"dict": lambda v, _: self.is_dict(v),
|
||||
"email": lambda v, _: self.is_email(v),
|
||||
"ip": lambda v, _: self.is_ip_address(v),
|
||||
}
|
||||
|
||||
for field, rules in schema.items():
|
||||
value = data.get(field)
|
||||
|
||||
# 必填检查
|
||||
if rules.get("required", False):
|
||||
if value is None:
|
||||
errors.append(f"缺少必填字段: {field}")
|
||||
continue
|
||||
elif value is None:
|
||||
continue
|
||||
|
||||
# 类型检查
|
||||
expected_type = rules.get("type", "string")
|
||||
checker = type_map.get(expected_type)
|
||||
if checker:
|
||||
if not checker(value, rules):
|
||||
errors.append(f"字段 '{field}' 类型错误,期望 {expected_type}")
|
||||
|
||||
# 正则匹配(string 类型)
|
||||
pattern = rules.get("pattern")
|
||||
if pattern and isinstance(value, str):
|
||||
if not re.match(pattern, value):
|
||||
errors.append(f"字段 '{field}' 格式不匹配: {pattern}")
|
||||
|
||||
# 嵌套 dict 校验
|
||||
nested = rules.get("fields")
|
||||
if nested and isinstance(value, dict):
|
||||
errors.extend(self.validate_schema(value, nested))
|
||||
|
||||
return errors
|
||||
|
||||
# ── 快捷校验 ──
|
||||
|
||||
def validate_or_raise(self, data: dict, schema: dict):
|
||||
"""校验失败抛出 ValidationError"""
|
||||
errors = self.validate_schema(data, schema)
|
||||
if errors:
|
||||
raise ValidationError(errors[0])
|
||||
|
||||
|
||||
_validator_instance: Optional[InputValidator] = None
|
||||
|
||||
|
||||
def get_validator() -> InputValidator:
|
||||
global _validator_instance
|
||||
if _validator_instance is None:
|
||||
_validator_instance = InputValidator()
|
||||
return _validator_instance
|
||||
106
oss/core/security/jwt_auth.py
Normal file
106
oss/core/security/jwt_auth.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""JWT 认证 — 签发/验证/中间件"""
|
||||
import json
|
||||
import time
|
||||
import hashlib
|
||||
import hmac as hmac_mod
|
||||
import base64
|
||||
from typing import Optional
|
||||
|
||||
from oss.config import get_config
|
||||
from oss.logger.logger import Log
|
||||
|
||||
|
||||
class JWTError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class JWTAuth:
|
||||
"""JWT 签发与验证(HMAC-SHA256,无外部依赖)"""
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
HEADER = base64.b64encode(json.dumps({"alg": "HS256", "typ": "JWT"}).encode()).rstrip(b"=").decode()
|
||||
|
||||
def __init__(self, secret: str = None):
|
||||
config = get_config()
|
||||
self._secret = secret or config.get("JWT_SECRET", "")
|
||||
if not self._secret:
|
||||
self._secret = hashlib.sha256(config.get("API_KEY", "nebula-default-secret").encode()).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _b64url(data: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(data).rstrip(b"=").decode()
|
||||
|
||||
@staticmethod
|
||||
def _unb64url(data: str) -> bytes:
|
||||
padding = 4 - len(data) % 4
|
||||
if padding != 4:
|
||||
data += "=" * padding
|
||||
return base64.urlsafe_b64decode(data)
|
||||
|
||||
def _sign(self, payload_b64: str) -> str:
|
||||
msg = f"{JWTAuth.HEADER}.{payload_b64}".encode()
|
||||
sig = hmac_mod.new(self._secret.encode(), msg, hashlib.sha256).digest()
|
||||
return self._b64url(sig)
|
||||
|
||||
def issue(self, user_id: str, role: str = "admin", expire_hours: int = 24) -> str:
|
||||
"""签发 JWT Token"""
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"role": role,
|
||||
"iat": int(time.time()),
|
||||
"exp": int(time.time()) + expire_hours * 3600,
|
||||
}
|
||||
payload_b64 = self._b64url(json.dumps(payload).encode())
|
||||
signature = self._sign(payload_b64)
|
||||
return f"{JWTAuth.HEADER}.{payload_b64}.{signature}"
|
||||
|
||||
def verify(self, token: str) -> Optional[dict]:
|
||||
"""验证 JWT Token,返回 payload 或 None"""
|
||||
try:
|
||||
parts = token.split(".")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
header_b64, payload_b64, sig_b64 = parts
|
||||
|
||||
# 验签
|
||||
expected_sig = self._sign(payload_b64)
|
||||
if not hmac_mod.compare_digest(expected_sig, sig_b64):
|
||||
return None
|
||||
|
||||
# 解码 payload
|
||||
payload = json.loads(self._unb64url(payload_b64))
|
||||
|
||||
# 检查过期
|
||||
if payload.get("exp", 0) < time.time():
|
||||
return None
|
||||
|
||||
return payload
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def extract_token(auth_header: str) -> Optional[str]:
|
||||
"""从 Authorization 头提取 Bearer Token"""
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
return None
|
||||
return auth_header[7:]
|
||||
|
||||
|
||||
# ── 快捷方法 ──
|
||||
|
||||
_auth_instance: Optional[JWTAuth] = None
|
||||
|
||||
|
||||
def get_jwt_auth() -> JWTAuth:
|
||||
global _auth_instance
|
||||
if _auth_instance is None:
|
||||
_auth_instance = JWTAuth()
|
||||
return _auth_instance
|
||||
|
||||
|
||||
def issue_token(user_id: str, role: str = "admin") -> str:
|
||||
return get_jwt_auth().issue(user_id, role)
|
||||
|
||||
|
||||
def verify_token(token: str) -> Optional[dict]:
|
||||
return get_jwt_auth().verify(token)
|
||||
95
oss/core/security/tls.py
Normal file
95
oss/core/security/tls.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""HTTPS 支持 — 自签名证书生成 + TLS 上下文加载"""
|
||||
import os
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.x509.oid import NameOID
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
from oss.config import get_config
|
||||
from oss.logger.logger import Log
|
||||
|
||||
|
||||
class TLSManager:
|
||||
"""TLS 证书管理"""
|
||||
|
||||
@staticmethod
|
||||
def ensure_cert(cert_dir: str = None) -> tuple[str, str]:
|
||||
"""确保证书存在,不存在则生成自签名证书
|
||||
|
||||
Returns:
|
||||
(cert_path, key_path)
|
||||
"""
|
||||
config = get_config()
|
||||
cert_dir = cert_dir or config.get("TLS_CERT_DIR", "./data/tls")
|
||||
cert_path = Path(cert_dir) / "server.crt"
|
||||
key_path = Path(cert_dir) / "server.key"
|
||||
|
||||
if cert_path.exists() and key_path.exists():
|
||||
return str(cert_path), str(key_path)
|
||||
|
||||
Log.info("TLS", "生成自签名证书...")
|
||||
cert_dir_path = Path(cert_dir)
|
||||
cert_dir_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
TLSManager._generate_self_signed(cert_path, key_path)
|
||||
Log.ok("TLS", f"自签名证书已生成: {cert_path}")
|
||||
return str(cert_path), str(key_path)
|
||||
|
||||
@staticmethod
|
||||
def _generate_self_signed(cert_path: Path, key_path: Path):
|
||||
"""生成自签名证书"""
|
||||
key = rsa.generate_private_key(
|
||||
public_exponent=65537, key_size=2048, backend=default_backend()
|
||||
)
|
||||
key_path.write_bytes(key.private_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
serialization.NoEncryption(),
|
||||
))
|
||||
|
||||
subject = issuer = x509.Name([
|
||||
x509.NameAttribute(NameOID.COUNTRY_NAME, "CN"),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "NebulaShell"),
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, "localhost"),
|
||||
])
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(subject)
|
||||
.issuer_name(issuer)
|
||||
.public_key(key.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now)
|
||||
.not_valid_after(now + datetime.timedelta(days=365))
|
||||
.add_extension(
|
||||
x509.SubjectAlternativeName([
|
||||
x509.DNSName("localhost"),
|
||||
x509.DNSName("127.0.0.1"),
|
||||
]),
|
||||
critical=False,
|
||||
)
|
||||
.sign(key, hashes.SHA256(), default_backend())
|
||||
)
|
||||
cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM))
|
||||
|
||||
@staticmethod
|
||||
def create_ssl_context(cert_path: str = None, key_path: str = None) -> Optional[object]:
|
||||
"""创建 SSL 上下文(用于 HTTPS 服务器)"""
|
||||
try:
|
||||
import ssl
|
||||
cert, key = TLSManager.ensure_cert()
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
ctx.load_cert_chain(
|
||||
cert_path or cert,
|
||||
key_path or key,
|
||||
)
|
||||
return ctx
|
||||
except Exception as e:
|
||||
Log.error("TLS", f"创建 SSL 上下文失败: {e}")
|
||||
return None
|
||||
Reference in New Issue
Block a user