feat: Phase 1 - 安全中间件 + 运维工具箱
Some checks failed
CI / test (3.10) (push) Has been cancelled
CI / test (3.11) (push) Has been cancelled
CI / test (3.12) (push) Has been cancelled
CI / test (3.13) (push) Has been cancelled

新增 oss/core/security/ 模块(852行):
- jwt_auth.py: JWT签发/验证(HMAC-SHA256,零外部依赖)
- csrf.py: CSRF Token生成与校验
- input_validator.py: JSON Schema校验+类型强制
- tls.py: 自签名证书生成+SSL上下文

新增 oss/core/ops/ 模块:
- health.py: 增强版/health端点(CPU/内存/磁盘/运行时间)
- metrics.py: Prometheus兼容/metrics端点

对接改造:
- engine.py: 导出新模块
- manager.py: 注册/api/login /health /metrics路由
- middleware.py: CSRF+InputValidation中间件
- config.py: JWT_SECRET/CSRF_SECRET等配置项
- security.py→security/__init__.py: 合并插件沙箱与HTTP安全
This commit is contained in:
2026-05-17 15:42:40 +08:00
parent e67d2d8ef6
commit 5e957096fa
12 changed files with 754 additions and 56 deletions

View File

@@ -0,0 +1,272 @@
"""安全模块 — 插件沙箱 + HTTP 安全中间件
插件沙箱PluginProxy, IntegrityChecker, MemoryGuard, AuditLogger, TamperMonitor, FallbackManager
HTTP 安全JWT, CSRF, InputValidator, TLS
"""
# ── 插件沙箱(来自 oss/core/security.py ──
from __future__ import annotations
import threading
import hashlib
import time
import json
import functools
from pathlib import Path
from typing import Any, Optional, Callable, TYPE_CHECKING
from collections import deque
from oss.logger.logger import Log
if TYPE_CHECKING:
from oss.core.manager import PluginManager
class PluginPermissionError(Exception):
"""插件权限错误"""
pass
class PluginProxy:
"""插件代理 - 防止越级访问"""
def __init__(self, plugin_name: str, plugin_instance: Any, allowed_plugins: list[str], all_plugins: dict):
self._plugin_name = plugin_name
self._plugin_instance = plugin_instance
self._allowed_plugins = set(allowed_plugins)
self._all_plugins = all_plugins
def get_plugin(self, name: str) -> Any:
if name not in self._allowed_plugins and "*" not in self._allowed_plugins:
raise PluginPermissionError(f"插件 '{self._plugin_name}' 无权访问插件 '{name}'")
if name not in self._all_plugins:
return None
return self._all_plugins[name]["instance"]
def list_plugins(self) -> list[str]:
if "*" in self._allowed_plugins:
return list(self._all_plugins.keys())
return [n for n in self._allowed_plugins if n in self._all_plugins]
def get_capability(self, capability: str) -> Any:
return None
def __getattr__(self, name: str):
return getattr(self._plugin_instance, name)
class IntegrityChecker:
"""文件完整性检查"""
def __init__(self):
self._hashes: dict[str, str] = {}
def compute_hash(self, plugin_dir: Path) -> str:
hasher = hashlib.sha256()
for file_path in sorted(plugin_dir.rglob("*")):
if file_path.is_file() and "__pycache__" not in file_path.parts and file_path.name != "SIGNATURE":
rel_path = str(file_path.relative_to(plugin_dir))
hasher.update(rel_path.encode("utf-8"))
hasher.update(file_path.read_bytes())
return hasher.hexdigest()
def register(self, plugin_name: str, plugin_dir: Path):
self._hashes[plugin_name] = self.compute_hash(plugin_dir)
def verify(self, plugin_name: str, plugin_dir: Path) -> tuple[bool, str]:
if plugin_name not in self._hashes:
return False, f"插件 '{plugin_name}' 未注册完整性检查"
current = self.compute_hash(plugin_dir)
if current == self._hashes[plugin_name]:
return True, "完整性验证通过"
return False, f"文件 hash 不匹配,插件可能被篡改"
def get_hash(self, plugin_name: str) -> Optional[str]:
return self._hashes.get(plugin_name)
class MemoryGuard:
"""运行时内存保护 - 防止插件修改 Core 内部状态"""
FROZEN_ATTRS = {
"plugins", "capability_registry", "lifecycle_manager",
"dependency_resolver", "signature_verifier", "pl_injector",
"integrity_checker", "audit_logger", "tamper_monitor",
"fallback_manager", "http_server", "repl_shell",
}
def __init__(self, manager: PluginManager):
self._manager = manager
self._protected = True
def enable(self):
self._protected = True
def disable(self):
self._protected = False
def check_setattr(self, obj: Any, name: str, value: Any) -> bool:
if not self._protected:
return True
if obj is self._manager and name in self.FROZEN_ATTRS:
Log.warn("Core", f"内存防护: 阻止了对 Core 内部属性 '{name}' 的修改")
return False
return True
class AuditLogger:
"""插件行为审计"""
def __init__(self, max_logs: int = 1000):
self._logs: deque = deque(maxlen=max_logs)
self._enabled = True
def enable(self):
self._enabled = True
def disable(self):
self._enabled = False
def log(self, plugin_name: str, action: str, detail: str = ""):
if not self._enabled:
return
self._logs.append({"time": time.time(), "plugin": plugin_name, "action": action, "detail": detail})
def get_logs(self, plugin_name: str = None, limit: int = 50) -> list[dict]:
if plugin_name:
filtered = [log for log in self._logs if log["plugin"] == plugin_name]
else:
filtered = list(self._logs)
return filtered[-limit:]
def get_stats(self) -> dict:
stats: dict[str, int] = {}
for log in self._logs:
stats[log["plugin"]] = stats.get(log["plugin"], 0) + 1
return stats
class TamperMonitor:
"""防篡改监控"""
def __init__(self, manager: PluginManager, interval: int = 30):
self._manager = manager
self._interval = interval
self._running = False
self._thread = None
self._alerts: deque = deque(maxlen=100)
def start(self):
self._running = True
self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
self._thread.start()
Log.info("Core", f"防篡改监控已启动 (间隔: {self._interval}s)")
def stop(self):
self._running = False
if self._thread:
self._thread.join(timeout=5)
def _monitor_loop(self):
while self._running:
try:
for plugin_name, info in self._manager.plugins.items():
plugin_dir = self._manager._get_plugin_dir(plugin_name)
if not plugin_dir:
continue
valid, msg = self._manager.integrity_checker.verify(plugin_name, plugin_dir)
if not valid:
alert = {"time": time.time(), "plugin": plugin_name, "message": msg}
self._alerts.append(alert)
Log.error("Core", f"防篡改告警: 插件 '{plugin_name}' 可能被篡改!")
try:
info["instance"].stop()
lifecycle = self._manager.lifecycle_manager.get(plugin_name)
if lifecycle:
lifecycle.mark_crashed()
except Exception as e:
Log.error("Core", f"停止被篡改插件 '{plugin_name}' 失败: {e}")
except Exception as e:
Log.error("Core", f"防篡改监控异常: {e}")
time.sleep(self._interval)
def get_alerts(self) -> list[dict]:
return list(self._alerts)
class FallbackManager:
"""降级恢复机制"""
def __init__(self, manager: PluginManager, max_retries: int = 3):
self._manager = manager
self._max_retries = max_retries
self._retry_counts: dict[str, int] = {}
self._degraded: set[str] = set()
def wrap_plugin_method(self, plugin_name: str, method: Callable) -> Callable:
@functools.wraps(method)
def safe_method(*args, **kwargs):
try:
return method(*args, **kwargs)
except Exception as e:
Log.error("Core", f"插件 '{plugin_name}' 方法 '{method.__name__}' 异常: {e}")
self._handle_crash(plugin_name)
return None
return safe_method
def _handle_crash(self, plugin_name: str):
retry_count = self._retry_counts.get(plugin_name, 0)
lifecycle = self._manager.lifecycle_manager.get(plugin_name)
bridge = self._manager._get_bridge()
if bridge and plugin_name != "plugin-bridge":
bridge.emit("plugin.crashed", name=plugin_name, retry=retry_count)
if retry_count < self._max_retries:
self._retry_counts[plugin_name] = retry_count + 1
Log.warn("Core", f"插件 '{plugin_name}' 崩溃,正在重启 (第 {retry_count + 1}/{self._max_retries} 次)")
try:
if lifecycle:
lifecycle.mark_crashed()
self._manager._restart_plugin(plugin_name)
if lifecycle:
lifecycle.start()
Log.ok("Core", f"插件 '{plugin_name}' 重启成功")
except Exception as e:
Log.error("Core", f"插件 '{plugin_name}' 重启失败: {e}")
else:
Log.error("Core", f"插件 '{plugin_name}' 超过最大重试次数,标记为降级")
self._degraded.add(plugin_name)
if lifecycle:
lifecycle.mark_degraded()
def recover(self, plugin_name: str) -> bool:
if plugin_name not in self._degraded:
return False
self._retry_counts[plugin_name] = 0
self._degraded.discard(plugin_name)
try:
self._manager._restart_plugin(plugin_name)
lifecycle = self._manager.lifecycle_manager.get(plugin_name)
if lifecycle:
lifecycle.start()
Log.ok("Core", f"插件 '{plugin_name}' 已手动恢复")
return True
except Exception as e:
Log.error("Core", f"恢复插件 '{plugin_name}' 失败: {e}")
return False
def is_degraded(self, plugin_name: str) -> bool:
return plugin_name in self._degraded
def get_degraded_plugins(self) -> list[str]:
return list(self._degraded)
# ── HTTP 安全中间件 ──
from .jwt_auth import JWTAuth, JWTError, issue_token, verify_token, get_jwt_auth
from .csrf import CSRFProtection
from .input_validator import InputValidator, ValidationError, get_validator
from .tls import TLSManager
__all__ = [
# 插件沙箱
"PluginPermissionError", "PluginProxy", "IntegrityChecker",
"MemoryGuard", "AuditLogger", "TamperMonitor", "FallbackManager",
# HTTP 安全
"JWTAuth", "JWTError", "issue_token", "verify_token", "get_jwt_auth",
"CSRFProtection",
"InputValidator", "ValidationError", "get_validator",
"TLSManager",
]

50
oss/core/security/csrf.py Normal file
View File

@@ -0,0 +1,50 @@
"""CSRF 防护 — Token 校验中间件"""
import secrets
import time
import hashlib
from typing import Optional
from oss.config import get_config
from oss.logger.logger import Log
class CSRFProtection:
"""CSRF Token 生成与验证"""
def __init__(self, secret: str = None):
config = get_config()
self._secret = secret or config.get("CSRF_SECRET", "")
if not self._secret:
self._secret = hashlib.sha256(config.get("API_KEY", "nebula-csrf-default").encode()).hexdigest()
self._token_ttl = config.get("CSRF_TOKEN_TTL", 3600) # 默认1小时
def generate_token(self, session_id: str) -> str:
"""生成 CSRF Token绑定 session"""
salt = secrets.token_hex(16)
timestamp = int(time.time())
raw = f"{session_id}:{salt}:{timestamp}:{self._secret}"
token = hashlib.sha256(raw.encode()).hexdigest()
return f"{timestamp}:{salt}:{token}"
def verify_token(self, session_id: str, token: str) -> bool:
"""验证 CSRF Token"""
try:
parts = token.split(":")
if len(parts) != 3:
return False
timestamp, salt, hash_val = parts
# 检查过期
if int(time.time()) - int(timestamp) > self._token_ttl:
return False
expected = hashlib.sha256(f"{session_id}:{salt}:{timestamp}:{self._secret}".encode()).hexdigest()
return hash_val == expected
except (ValueError, IndexError):
return False
SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}
@staticmethod
def is_safe_method(method: str) -> bool:
return method.upper() in CSRFProtection.SAFE_METHODS

View File

@@ -0,0 +1,158 @@
"""输入验证 — JSON Schema 校验 + 参数白名单 + 类型强制"""
import json
import re
from typing import Any, Optional
from oss.logger.logger import Log
class ValidationError(Exception):
def __init__(self, message: str, field: str = None):
self.field = field
super().__init__(message)
class InputValidator:
"""输入验证器"""
# ── 内置类型校验器 ──
@staticmethod
def is_string(val: Any, min_len: int = 0, max_len: int = None) -> bool:
if not isinstance(val, str):
return False
if len(val) < min_len:
return False
if max_len and len(val) > max_len:
return False
return True
@staticmethod
def is_integer(val: Any, min_val: int = None, max_val: int = None) -> bool:
if not isinstance(val, int) or isinstance(val, bool):
return False
if min_val is not None and val < min_val:
return False
if max_val is not None and val > max_val:
return False
return True
@staticmethod
def is_float(val: Any, min_val: float = None, max_val: float = None) -> bool:
if not isinstance(val, (int, float)) or isinstance(val, bool):
return False
if min_val is not None and val < min_val:
return False
if max_val is not None and val > max_val:
return False
return True
@staticmethod
def is_boolean(val: Any) -> bool:
return isinstance(val, bool)
@staticmethod
def is_list(val: Any, item_type: type = None) -> bool:
if not isinstance(val, list):
return False
if item_type:
return all(isinstance(v, item_type) for v in val)
return True
@staticmethod
def is_dict(val: Any) -> bool:
return isinstance(val, dict)
@staticmethod
def is_email(val: str) -> bool:
return bool(re.match(r"^[^\s@]+@[^\s@]+\.[^\s@]+$", val))
@staticmethod
def is_ip_address(val: str) -> bool:
return bool(re.match(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", val))
@staticmethod
def is_alphanumeric(val: str) -> bool:
return bool(re.match(r"^[a-zA-Z0-9_\-]+$", val))
# ── JSON Schema 校验 ──
def validate_schema(self, data: dict, schema: dict) -> list[str]:
"""根据 JSON Schema 校验数据,返回错误列表
Schema 格式:
{
"field_name": {
"type": "string|int|float|bool|list|dict|email|ip",
"required": True/False,
"min_len": 1, # 仅 string
"max_len": 100, # 仅 string
"min_val": 0, # 仅 int/float
"max_val": 100, # 仅 int/float
"pattern": "^...$", # 正则(仅 string
"default": "val", # 默认值(可选字段)
"items": "string", # list 元素类型
"fields": {...}, # 嵌套 dict schema
}
}
"""
errors = []
type_map = {
"string": lambda v, r: self.is_string(v, r.get("min_len", 0), r.get("max_len")),
"int": lambda v, r: self.is_integer(v, r.get("min_val"), r.get("max_val")),
"float": lambda v, r: self.is_float(v, r.get("min_val"), r.get("max_val")),
"bool": lambda v, _: self.is_boolean(v),
"list": lambda v, r: self.is_list(v),
"dict": lambda v, _: self.is_dict(v),
"email": lambda v, _: self.is_email(v),
"ip": lambda v, _: self.is_ip_address(v),
}
for field, rules in schema.items():
value = data.get(field)
# 必填检查
if rules.get("required", False):
if value is None:
errors.append(f"缺少必填字段: {field}")
continue
elif value is None:
continue
# 类型检查
expected_type = rules.get("type", "string")
checker = type_map.get(expected_type)
if checker:
if not checker(value, rules):
errors.append(f"字段 '{field}' 类型错误,期望 {expected_type}")
# 正则匹配string 类型)
pattern = rules.get("pattern")
if pattern and isinstance(value, str):
if not re.match(pattern, value):
errors.append(f"字段 '{field}' 格式不匹配: {pattern}")
# 嵌套 dict 校验
nested = rules.get("fields")
if nested and isinstance(value, dict):
errors.extend(self.validate_schema(value, nested))
return errors
# ── 快捷校验 ──
def validate_or_raise(self, data: dict, schema: dict):
"""校验失败抛出 ValidationError"""
errors = self.validate_schema(data, schema)
if errors:
raise ValidationError(errors[0])
_validator_instance: Optional[InputValidator] = None
def get_validator() -> InputValidator:
global _validator_instance
if _validator_instance is None:
_validator_instance = InputValidator()
return _validator_instance

View File

@@ -0,0 +1,106 @@
"""JWT 认证 — 签发/验证/中间件"""
import json
import time
import hashlib
import hmac as hmac_mod
import base64
from typing import Optional
from oss.config import get_config
from oss.logger.logger import Log
class JWTError(Exception):
pass
class JWTAuth:
"""JWT 签发与验证HMAC-SHA256无外部依赖"""
ALGORITHM = "HS256"
HEADER = base64.b64encode(json.dumps({"alg": "HS256", "typ": "JWT"}).encode()).rstrip(b"=").decode()
def __init__(self, secret: str = None):
config = get_config()
self._secret = secret or config.get("JWT_SECRET", "")
if not self._secret:
self._secret = hashlib.sha256(config.get("API_KEY", "nebula-default-secret").encode()).hexdigest()
@staticmethod
def _b64url(data: bytes) -> str:
return base64.urlsafe_b64encode(data).rstrip(b"=").decode()
@staticmethod
def _unb64url(data: str) -> bytes:
padding = 4 - len(data) % 4
if padding != 4:
data += "=" * padding
return base64.urlsafe_b64decode(data)
def _sign(self, payload_b64: str) -> str:
msg = f"{JWTAuth.HEADER}.{payload_b64}".encode()
sig = hmac_mod.new(self._secret.encode(), msg, hashlib.sha256).digest()
return self._b64url(sig)
def issue(self, user_id: str, role: str = "admin", expire_hours: int = 24) -> str:
"""签发 JWT Token"""
payload = {
"sub": user_id,
"role": role,
"iat": int(time.time()),
"exp": int(time.time()) + expire_hours * 3600,
}
payload_b64 = self._b64url(json.dumps(payload).encode())
signature = self._sign(payload_b64)
return f"{JWTAuth.HEADER}.{payload_b64}.{signature}"
def verify(self, token: str) -> Optional[dict]:
"""验证 JWT Token返回 payload 或 None"""
try:
parts = token.split(".")
if len(parts) != 3:
return None
header_b64, payload_b64, sig_b64 = parts
# 验签
expected_sig = self._sign(payload_b64)
if not hmac_mod.compare_digest(expected_sig, sig_b64):
return None
# 解码 payload
payload = json.loads(self._unb64url(payload_b64))
# 检查过期
if payload.get("exp", 0) < time.time():
return None
return payload
except Exception:
return None
@staticmethod
def extract_token(auth_header: str) -> Optional[str]:
"""从 Authorization 头提取 Bearer Token"""
if not auth_header or not auth_header.startswith("Bearer "):
return None
return auth_header[7:]
# ── 快捷方法 ──
_auth_instance: Optional[JWTAuth] = None
def get_jwt_auth() -> JWTAuth:
global _auth_instance
if _auth_instance is None:
_auth_instance = JWTAuth()
return _auth_instance
def issue_token(user_id: str, role: str = "admin") -> str:
return get_jwt_auth().issue(user_id, role)
def verify_token(token: str) -> Optional[dict]:
return get_jwt_auth().verify(token)

95
oss/core/security/tls.py Normal file
View File

@@ -0,0 +1,95 @@
"""HTTPS 支持 — 自签名证书生成 + TLS 上下文加载"""
import os
import datetime
from pathlib import Path
from typing import Optional
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
from oss.config import get_config
from oss.logger.logger import Log
class TLSManager:
"""TLS 证书管理"""
@staticmethod
def ensure_cert(cert_dir: str = None) -> tuple[str, str]:
"""确保证书存在,不存在则生成自签名证书
Returns:
(cert_path, key_path)
"""
config = get_config()
cert_dir = cert_dir or config.get("TLS_CERT_DIR", "./data/tls")
cert_path = Path(cert_dir) / "server.crt"
key_path = Path(cert_dir) / "server.key"
if cert_path.exists() and key_path.exists():
return str(cert_path), str(key_path)
Log.info("TLS", "生成自签名证书...")
cert_dir_path = Path(cert_dir)
cert_dir_path.mkdir(parents=True, exist_ok=True)
TLSManager._generate_self_signed(cert_path, key_path)
Log.ok("TLS", f"自签名证书已生成: {cert_path}")
return str(cert_path), str(key_path)
@staticmethod
def _generate_self_signed(cert_path: Path, key_path: Path):
"""生成自签名证书"""
key = rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend()
)
key_path.write_bytes(key.private_bytes(
serialization.Encoding.PEM,
serialization.PrivateFormat.TraditionalOpenSSL,
serialization.NoEncryption(),
))
subject = issuer = x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, "CN"),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "NebulaShell"),
x509.NameAttribute(NameOID.COMMON_NAME, "localhost"),
])
now = datetime.datetime.now(datetime.timezone.utc)
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(now)
.not_valid_after(now + datetime.timedelta(days=365))
.add_extension(
x509.SubjectAlternativeName([
x509.DNSName("localhost"),
x509.DNSName("127.0.0.1"),
]),
critical=False,
)
.sign(key, hashes.SHA256(), default_backend())
)
cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM))
@staticmethod
def create_ssl_context(cert_path: str = None, key_path: str = None) -> Optional[object]:
"""创建 SSL 上下文(用于 HTTPS 服务器)"""
try:
import ssl
cert, key = TLSManager.ensure_cert()
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ctx.load_cert_chain(
cert_path or cert,
key_path or key,
)
return ctx
except Exception as e:
Log.error("TLS", f"创建 SSL 上下文失败: {e}")
return None