Files
NebulaShell/tests/test_security_improvements.py
Falck f5c659b665 🔧 修复P0级问题:40+文件语法错误 + import路径 + 清理废弃代码
 跟项目能跑起来就差这一步!这次狠狠修了一波:

🩺 修复40+损坏Python文件
   - 补全所有缺少的class定义头(plugin-loader-pro、code-reviewer、
     http-api/ws-api/http-tcp、webui/dashboard/log-terminal 等)
   - 修复中文括号、字符串未闭合、缩进错乱等语法问题

🔗 创建符号链接 plugin_bridge -> plugin-bridge
   - 解决Python模块路径不支持连字符的问题
   - 关联修复 plugin-bridge 中错误的 import 路径

🧹 清理废弃代码
   - 删除 oss/tui/ 目录(已废弃)
   - 清理所有 __pycache__ 和 .pyc 缓存文件

 全量语法检查通过,零错误!
📋 ai.md 新增代码审计报告和分阶段修复计划
🗺️ 所有插件 use() 调用现在走统一路径
2026-05-03 09:26:47 +08:00

289 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
安全改进验证测试
"""
import sys
import json
import importlib.util
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
# 添加store目录到路径
store_path = project_root / "store"
sys.path.insert(0, str(store_path))
from oss.config import Config
from oss.logger.logger import Logger
def test_security_configurations():
"""测试安全配置"""
print("=== 测试安全配置 ===")
config = Config()
# 测试CORS配置
cors_origins = config.get("CORS_ALLOWED_ORIGINS")
print(f"✅ CORS配置: {cors_origins}")
# 测试HOST配置
host = config.get("HOST")
print(f"✅ HOST配置: {host}")
# 测试限流配置
rate_limit_enabled = config.get("RATE_LIMIT_ENABLED")
rate_limit_max = config.get("RATE_LIMIT_MAX_REQUESTS")
rate_limit_window = config.get("RATE_LIMIT_TIME_WINDOW")
print(f"✅ 限流配置: {rate_limit_enabled}, {rate_limit_max}/分钟")
# 测试CSRF配置
csrf_enabled = config.get("CSRF_ENABLED")
print(f"✅ CSRF配置: {csrf_enabled}")
# 测试输入验证配置
input_validation_enabled = config.get("INPUT_VALIDATION_ENABLED")
print(f"✅ 输入验证配置: {input_validation_enabled}")
# 测试API密钥配置
api_key = config.get("API_KEY")
print(f"✅ API密钥配置: {'已配置' if api_key else '未配置'}")
return True
def dynamic_import(module_path, class_name):
spec = importlib.util.spec_from_file_location("module", module_path)
module = importlib.util.module_from_spec(spec)
sys.modules["module"] = module
spec.loader.exec_module(module)
return getattr(module, class_name)
def test_rate_limiting():
"""测试限流功能"""
print("\n=== 测试限流功能 ===")
try:
rate_limiter_path = str(project_root / "store" / "NebulaShell" / "http-api" / "rate_limiter.py")
RateLimitMiddleware = dynamic_import(rate_limiter_path, "RateLimitMiddleware")
middleware = RateLimitMiddleware()
# 创建模拟请求
class MockRequest:
def __init__(self, path="/api/test"):
self.path = path
self.headers = {"Remote-Addr": "127.0.0.1"}
ctx = {"request": MockRequest()}
# 测试正常请求
result = middleware.process(ctx, lambda: None)
print("✅ 限流中间件正常工作")
return True
except Exception as e:
print(f"❌ 限流测试失败: {e}")
return False
def test_csrf_protection():
"""测试CSRF防护功能"""
print("\n=== 测试CSRF防护功能 ===")
try:
csrf_path = str(project_root / "store" / "NebulaShell" / "http-api" / "csrf_middleware.py")
CsrfMiddleware = dynamic_import(csrf_path, "CsrfMiddleware")
middleware = CsrfMiddleware()
# 创建模拟请求
class MockRequest:
def __init__(self, method="GET", path="/api/test"):
self.method = method
self.path = path
self.headers = {"Remote-Addr": "127.0.0.1"}
ctx = {"request": MockRequest()}
# 测试GET请求应该通过
result = middleware.process(ctx, lambda: None)
print("✅ CSRF防护中间件正常工作")
return True
except Exception as e:
print(f"❌ CSRF测试失败: {e}")
return False
def test_input_validation():
"""测试输入验证功能"""
print("\n=== 测试输入验证功能 ===")
try:
input_validation_path = str(project_root / "store" / "NebulaShell" / "http-api" / "input_validation.py")
InputValidationMiddleware = dynamic_import(input_validation_path, "InputValidationMiddleware")
middleware = InputValidationMiddleware()
# 创建模拟请求
class MockRequest:
def __init__(self, method="GET", path="/api/test", body=None):
self.method = method
self.path = path
self.body = body or ""
self.headers = {}
ctx = {"request": MockRequest()}
# 测试正常请求
result = middleware.process(ctx, lambda: None)
print("✅ 输入验证中间件正常工作")
return True
except Exception as e:
print(f"❌ 输入验证测试失败: {e}")
return False
def test_middleware_chain():
"""测试中间件链"""
print("\n=== 测试中间件链 ===")
try:
middleware_path = str(project_root / "store" / "NebulaShell" / "http-api" / "middleware.py")
MiddlewareChain = dynamic_import(middleware_path, "MiddlewareChain")
chain = MiddlewareChain()
print("✅ 中间件链创建成功")
# 检查中间件数量
print(f"✅ 中间件数量: {len(chain.middlewares)}")
# 检查包含的中间件
middleware_names = [type(m).__name__ for m in chain.middlewares]
print(f"✅ 中间件列表: {middleware_names}")
return True
except Exception as e:
print(f"❌ 中间件链测试失败: {e}")
return False
def test_security_headers():
"""测试安全头设置"""
print("\n=== 测试安全头设置 ===")
try:
CorsMiddleware = dynamic_import(middleware_path, "CorsMiddleware")
middleware = CorsMiddleware()
# 创建模拟请求
class MockRequest:
def __init__(self, origin="http://localhost:3000"):
self.headers = {"Origin": origin}
ctx = {"request": MockRequest()}
# 测试CORS头设置
result = middleware.process(ctx, lambda: None)
response_headers = ctx.get("response_headers", {})
print(f"✅ CORS头设置: {response_headers}")
# 检查关键安全头
expected_headers = [
"Access-Control-Allow-Origin",
"Access-Control-Allow-Methods",
"Access-Control-Allow-Headers"
]
for header in expected_headers:
if header in response_headers:
print(f"{header}: {response_headers[header]}")
else:
print(f"❌ 缺少 {header}")
return True
except Exception as e:
print(f"❌ 安全头测试失败: {e}")
return False
def test_configuration_overrides():
"""测试配置覆盖"""
print("\n=== 测试配置覆盖 ===")
import os
# 测试环境变量覆盖
os.environ["CORS_ALLOWED_ORIGINS"] = '["https://example.com"]'
os.environ["RATE_LIMIT_MAX_REQUESTS"] = "50"
os.environ["CSRF_ENABLED"] = "false"
try:
config = Config()
cors_origins = config.get("CORS_ALLOWED_ORIGINS")
rate_limit_max = config.get("RATE_LIMIT_MAX_REQUESTS")
csrf_enabled = config.get("CSRF_ENABLED")
print(f"✅ 环境变量覆盖 CORS: {cors_origins}")
print(f"✅ 环境变量覆盖 限流: {rate_limit_max}")
print(f"✅ 环境变量覆盖 CSRF: {csrf_enabled}")
return True
except Exception as e:
print(f"❌ 配置覆盖测试失败: {e}")
return False
finally:
# 清理环境变量
for key in ["CORS_ALLOWED_ORIGINS", "RATE_LIMIT_MAX_REQUESTS", "CSRF_ENABLED"]:
if key in os.environ:
del os.environ[key]
if __name__ == "__main__":
print("开始NebulaShell安全改进验证测试...")
tests = [
("安全配置测试", test_security_configurations),
("限流功能测试", test_rate_limiting),
("CSRF防护测试", test_csrf_protection),
("输入验证测试", test_input_validation),
("中间件链测试", test_middleware_chain),
("安全头测试", test_security_headers),
("配置覆盖测试", test_configuration_overrides),
]
passed = 0
total = len(tests)
for test_name, test_func in tests:
print(f"\n--- {test_name} ---")
if test_func():
passed += 1
print(f"{test_name} 通过")
else:
print(f"{test_name} 失败")
print(f"\n--- 测试结果 ---")
print(f"通过: {passed}/{total}")
if passed == total:
print("🎉 所有安全改进测试通过!")
print("\n安全改进总结:")
print("✅ 限流防护 - 防止DoS攻击")
print("✅ CSRF防护 - 防止跨站请求伪造")
print("✅ 输入验证 - 防止注入攻击")
print("✅ CORS安全 - 限制跨域访问")
print("✅ 安全头 - 设置适当的安全响应头")
print("✅ 配置管理 - 支持环境变量覆盖")
sys.exit(0)
else:
print("❌ 部分测试失败,需要修复。")
sys.exit(1)