修复项目主要错误
This commit is contained in:
@@ -42,4 +42,15 @@ class QualityCheck:
|
||||
return issues
|
||||
|
||||
def _calculate_complexity(self, node: ast.AST) -> int:
|
||||
pass
|
||||
"""计算圈复杂度"""
|
||||
complexity = 1
|
||||
for child in ast.walk(node):
|
||||
if isinstance(child, (ast.If, ast.While, ast.For, ast.AsyncFor)):
|
||||
complexity += 1
|
||||
elif isinstance(child, ast.ExceptHandler):
|
||||
complexity += 1
|
||||
elif isinstance(child, ast.BoolOp):
|
||||
complexity += len(child.values) - 1
|
||||
elif isinstance(child, (ast.And, ast.Or)):
|
||||
complexity += 1
|
||||
return complexity
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
import ast
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ReferenceCheck:
|
||||
STD_MODULES = {
|
||||
'os', 'sys', 'json', 're', 'time', 'datetime', 'pathlib',
|
||||
@@ -41,30 +46,40 @@ class ReferenceCheck:
|
||||
self._scan_project_modules()
|
||||
|
||||
def _scan_project_modules(self):
|
||||
if dir_path.exists():
|
||||
for item in dir_path.iterdir():
|
||||
if item.is_file() and item.name.endswith(".py") and item.name != "__init__.py":
|
||||
module_name = item.name[:-3]
|
||||
full_name = f"{base_name}.{module_name}"
|
||||
self._available_modules.add(full_name)
|
||||
elif item.is_dir() and (item / "__init__.py").exists():
|
||||
full_name = f"{base_name}.{item.name}"
|
||||
self._available_modules.add(full_name)
|
||||
self._scan_module_dir(item, full_name)
|
||||
"""扫描项目目录下的所有 Python 模块"""
|
||||
store_dir = self.project_root / "store"
|
||||
if not store_dir.exists():
|
||||
return
|
||||
|
||||
for author_dir in store_dir.iterdir():
|
||||
if not author_dir.is_dir():
|
||||
continue
|
||||
for plugin_dir in author_dir.iterdir():
|
||||
if not plugin_dir.is_dir():
|
||||
continue
|
||||
self._scan_plugin_modules(plugin_dir, plugin_dir.name)
|
||||
|
||||
def _scan_plugin_modules(self, plugin_dir: Path, base_name: str):
|
||||
if dir_path.exists():
|
||||
for item in dir_path.iterdir():
|
||||
if item.is_file() and item.name.endswith(".py") and item.name != "__init__.py":
|
||||
module_name = item.name[:-3]
|
||||
self._available_modules.add(f"{base_name}.{module_name}")
|
||||
elif item.is_dir() and (item / "__init__.py").exists():
|
||||
self._add_module_from_dir(item, f"{base_name}.{item.name}")
|
||||
"""扫描单个插件目录下的模块"""
|
||||
if not plugin_dir.exists():
|
||||
return
|
||||
|
||||
for item in plugin_dir.iterdir():
|
||||
if item.is_file() and item.name.endswith(".py") and item.name != "__init__.py":
|
||||
module_name = item.name[:-3]
|
||||
self._available_modules.add(f"{base_name}.{module_name}")
|
||||
elif item.is_dir() and (item / "__init__.py").exists():
|
||||
self._available_modules.add(f"{base_name}.{item.name}")
|
||||
|
||||
def check(self, filepath: str, content: str) -> list:
|
||||
issues = []
|
||||
file_path = Path(filepath)
|
||||
|
||||
try:
|
||||
tree = ast.parse(content)
|
||||
except SyntaxError:
|
||||
return []
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
@@ -103,25 +118,8 @@ class ReferenceCheck:
|
||||
|
||||
return issues
|
||||
|
||||
def _check_variable_references(self, filepath: str, tree: ast.AST, content: str) -> list:
|
||||
issues = []
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Attribute):
|
||||
if isinstance(node.value, ast.Name):
|
||||
var_name = node.value.id
|
||||
if var_name in ('None', 'True', 'False'):
|
||||
issues.append({
|
||||
"file": filepath,
|
||||
"line": node.lineno,
|
||||
"severity": "critical",
|
||||
"type": "attribute_error",
|
||||
"message": f"尝试访问 {var_name} 的属性: {node.attr}"
|
||||
})
|
||||
|
||||
return issues
|
||||
|
||||
def _check_function_calls(self, filepath: str, tree: ast.AST, content: str) -> list:
|
||||
def _is_module_available(self, module_name: str, file_path: Optional[Path] = None) -> bool:
|
||||
"""检查模块是否可用"""
|
||||
if module_name in self._available_modules:
|
||||
return True
|
||||
|
||||
@@ -154,5 +152,49 @@ class ReferenceCheck:
|
||||
|
||||
return False
|
||||
|
||||
def _check_variable_references(self, filepath: str, tree: ast.AST, content: str) -> list:
|
||||
issues = []
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Attribute):
|
||||
if isinstance(node.value, ast.Name):
|
||||
var_name = node.value.id
|
||||
if var_name in ('None', 'True', 'False'):
|
||||
issues.append({
|
||||
"file": filepath,
|
||||
"line": node.lineno,
|
||||
"severity": "critical",
|
||||
"type": "attribute_error",
|
||||
"message": f"尝试访问 {var_name} 的属性: {node.attr}"
|
||||
})
|
||||
|
||||
return issues
|
||||
|
||||
def _is_name_defined(self, name: str, tree: ast.AST, line: int) -> bool:
|
||||
pass
|
||||
"""检查变量名是否在 AST 中定义"""
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
if node.name == name:
|
||||
return True
|
||||
for arg in node.args.args:
|
||||
if arg.arg == name:
|
||||
return True
|
||||
elif isinstance(node, ast.Assign):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name) and target.id == name:
|
||||
return True
|
||||
elif isinstance(node, ast.AnnAssign):
|
||||
if isinstance(node.target, ast.Name) and node.target.id == name:
|
||||
return True
|
||||
elif isinstance(node, ast.ClassDef):
|
||||
if node.name == name:
|
||||
return True
|
||||
elif isinstance(node, ast.With):
|
||||
for item in node.items:
|
||||
if item.optional_vars and isinstance(item.optional_vars, ast.Name):
|
||||
if item.optional_vars.id == name:
|
||||
return True
|
||||
elif isinstance(node, ast.ExceptHandler):
|
||||
if node.name and node.name == name:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -1,34 +1,78 @@
|
||||
class Reviewer:
|
||||
import ast
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from checks.security import SecurityCheck
|
||||
from checks.quality import QualityCheck
|
||||
from checks.style import StyleCheck
|
||||
from checks.references import ReferenceCheck
|
||||
from report.formatter import Formatter as ReportFormatter
|
||||
|
||||
|
||||
class CodeReviewer:
|
||||
def __init__(self, config: dict):
|
||||
self.config = config
|
||||
self.security = SecurityChecker()
|
||||
self.quality = QualityChecker()
|
||||
self.style = StyleChecker()
|
||||
self.references = ReferenceChecker()
|
||||
self.security = SecurityCheck()
|
||||
self.quality = QualityCheck()
|
||||
self.style = StyleCheck()
|
||||
self.references = ReferenceCheck()
|
||||
self.formatter = ReportFormatter(config.get("report_format", "console"))
|
||||
|
||||
def run_check(self, scan_dirs: list) -> dict:
|
||||
issues = []
|
||||
files_scanned = 0
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
exclude_patterns = self.config.get("exclude_patterns", ["__pycache__"])
|
||||
max_file_size = self.config.get("max_file_size", 102400)
|
||||
|
||||
issues.extend(self.security.check(filepath, content))
|
||||
for scan_dir in scan_dirs:
|
||||
scan_path = Path(scan_dir)
|
||||
if not scan_path.exists() or not scan_path.is_dir():
|
||||
continue
|
||||
|
||||
issues.extend(self.quality.check(filepath, content))
|
||||
for py_file in scan_path.rglob("*.py"):
|
||||
# 跳过排除目录
|
||||
if any(part in exclude_patterns for part in py_file.parts):
|
||||
continue
|
||||
|
||||
issues.extend(self.style.check(filepath, content))
|
||||
filepath = str(py_file)
|
||||
|
||||
issues.extend(self.references.check(filepath, content))
|
||||
# 检查文件大小
|
||||
if py_file.stat().st_size > max_file_size:
|
||||
issues.append({
|
||||
"file": filepath,
|
||||
"line": 0,
|
||||
"severity": "warning",
|
||||
"type": "file_too_large",
|
||||
"message": f"文件过大 ({py_file.stat().st_size} 字节),跳过检查"
|
||||
})
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
issues.append({
|
||||
"file": filepath,
|
||||
"line": 0,
|
||||
"severity": "error",
|
||||
"type": "parse_error",
|
||||
"message": f"文件解析失败: {e}"
|
||||
})
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
return issues
|
||||
files_scanned += 1
|
||||
issues.extend(self.security.check(filepath, content))
|
||||
issues.extend(self.quality.check(filepath, content))
|
||||
issues.extend(self.style.check(filepath, content))
|
||||
issues.extend(self.references.check(filepath, content))
|
||||
|
||||
except Exception as e:
|
||||
issues.append({
|
||||
"file": filepath,
|
||||
"line": 0,
|
||||
"severity": "error",
|
||||
"type": "parse_error",
|
||||
"message": f"文件解析失败: {e}"
|
||||
})
|
||||
|
||||
scan_time = round(time.time() - start_time, 2)
|
||||
return {
|
||||
"files_scanned": files_scanned,
|
||||
"total_issues": len(issues),
|
||||
"issues": issues,
|
||||
"scan_time": scan_time
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ class CodeReviewerPlugin:
|
||||
"report_format": config.get("report_format", "console")
|
||||
}
|
||||
|
||||
from core.reviewer import CodeReviewer
|
||||
self.reviewer = CodeReviewer(self.config)
|
||||
Log.info("code-reviewer", "初始化完成")
|
||||
|
||||
@@ -46,4 +47,9 @@ class CodeReviewerPlugin:
|
||||
Log.error("code-reviewer", "插件已停止")
|
||||
|
||||
def check(self, dirs: list = None) -> dict:
|
||||
pass
|
||||
if not self.reviewer:
|
||||
return {"error": "code-reviewer 未初始化"}
|
||||
|
||||
scan_dirs = dirs if dirs else self.config.get("scan_dirs", ["store", "oss"])
|
||||
result = self.reviewer.run_check(scan_dirs)
|
||||
return result
|
||||
|
||||
@@ -38,4 +38,6 @@ class Formatter:
|
||||
return '\n'.join(lines)
|
||||
|
||||
def _format_json(self, result: dict) -> str:
|
||||
pass
|
||||
"""以 JSON 格式输出审查报告"""
|
||||
import json
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
CSRF 防护中间件
|
||||
"""
|
||||
import json
|
||||
import hashlib
|
||||
import secrets
|
||||
import time
|
||||
@@ -165,7 +166,6 @@ class CsrfMiddleware:
|
||||
csrf_token = None
|
||||
if request.headers.get("Content-Type") == "application/json":
|
||||
try:
|
||||
import json
|
||||
body = json.loads(request.body)
|
||||
csrf_token = body.get("csrf_token")
|
||||
except:
|
||||
|
||||
@@ -1,13 +1,27 @@
|
||||
class HttpApiPlugin:
|
||||
import json
|
||||
|
||||
from oss.plugin.types import Plugin, register_plugin_type
|
||||
from .server import HttpServer, Response
|
||||
from .router import HttpRouter
|
||||
from .middleware import MiddlewareChain
|
||||
|
||||
|
||||
class HttpApiPlugin(Plugin):
|
||||
def __init__(self):
|
||||
self.server = None
|
||||
self.router = Router()
|
||||
self.router = HttpRouter()
|
||||
self.middleware = MiddlewareChain()
|
||||
|
||||
def init(self, deps: dict = None):
|
||||
self.server = HttpServer(
|
||||
router=self.router,
|
||||
middleware=self.middleware,
|
||||
)
|
||||
self.server.start()
|
||||
|
||||
def stop(self):
|
||||
if self.server:
|
||||
self.server.stop()
|
||||
return Response(
|
||||
status=200,
|
||||
body=json.dumps({"status": "ok", "service": "http-api"}),
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""中间件链 - CORS/鉴权/日志/限流/CSRF/输入验证等"""
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
from collections import deque
|
||||
from typing import Callable, Optional, Any
|
||||
|
||||
from oss.config import get_config
|
||||
@@ -110,13 +112,7 @@ class RateLimitMiddleware(Middleware):
|
||||
|
||||
# 请求记录
|
||||
self.requests = {}
|
||||
self.lock = None # 延迟初始化
|
||||
|
||||
def _init_lock(self):
|
||||
"""延迟初始化锁"""
|
||||
if self.lock is None:
|
||||
import threading
|
||||
self.lock = threading.Lock()
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def _get_client_identifier(self, request: Request) -> str:
|
||||
"""获取客户端标识符"""
|
||||
@@ -152,21 +148,22 @@ class RateLimitMiddleware(Middleware):
|
||||
max_requests = limit["max_requests"]
|
||||
time_window = limit["time_window"]
|
||||
|
||||
# 清理过期的请求记录
|
||||
if limit_key not in self.requests:
|
||||
self.requests[limit_key] = []
|
||||
|
||||
request_times = self.requests[limit_key]
|
||||
while request_times and request_times[0] <= now - time_window:
|
||||
request_times.popleft()
|
||||
|
||||
# 检查是否超过限制
|
||||
if len(request_times) >= max_requests:
|
||||
return True
|
||||
|
||||
# 记录当前请求
|
||||
request_times.append(now)
|
||||
return False
|
||||
with self.lock:
|
||||
# 清理过期的请求记录
|
||||
if limit_key not in self.requests:
|
||||
self.requests[limit_key] = deque()
|
||||
|
||||
request_times = self.requests[limit_key]
|
||||
while request_times and request_times[0] <= now - time_window:
|
||||
request_times.popleft()
|
||||
|
||||
# 检查是否超过限制
|
||||
if len(request_times) >= max_requests:
|
||||
return True
|
||||
|
||||
# 记录当前请求
|
||||
request_times.append(now)
|
||||
return False
|
||||
|
||||
def _create_rate_limit_response(self) -> Response:
|
||||
"""创建限流响应"""
|
||||
@@ -191,7 +188,6 @@ class RateLimitMiddleware(Middleware):
|
||||
return next_fn()
|
||||
|
||||
# 获取客户端标识符
|
||||
self._init_lock()
|
||||
identifier = self._get_client_identifier(request)
|
||||
|
||||
# 检查是否被限流
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
"""
|
||||
限流中间件 - 防止DoS攻击
|
||||
限流工具 - 令牌桶限流器
|
||||
"""
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict
|
||||
from collections import defaultdict, deque
|
||||
|
||||
from oss.config import get_config
|
||||
from store.NebulaShell.http_api.server import Response
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""令牌桶限流器"""
|
||||
@@ -35,88 +32,4 @@ class RateLimiter:
|
||||
|
||||
# 记录当前请求
|
||||
request_times.append(now)
|
||||
return True
|
||||
|
||||
|
||||
class RateLimitMiddleware:
|
||||
"""限流中间件"""
|
||||
|
||||
def __init__(self):
|
||||
self.config = get_config()
|
||||
self.limiter = RateLimiter(
|
||||
max_requests=self.config.get("RATE_LIMIT_MAX_REQUESTS", 100),
|
||||
time_window=self.config.get("RATE_LIMIT_TIME_WINDOW", 60)
|
||||
)
|
||||
self.enabled = self.config.get("RATE_LIMIT_ENABLED", True)
|
||||
|
||||
# 不同端点的限流配置
|
||||
self.endpoint_limits = {
|
||||
"/api/dashboard/stats": {
|
||||
"max_requests": 10,
|
||||
"time_window": 60
|
||||
},
|
||||
"/api/pkg-manager/search": {
|
||||
"max_requests": 50,
|
||||
"time_window": 60
|
||||
}
|
||||
}
|
||||
|
||||
def get_client_identifier(self, request) -> str:
|
||||
"""获取客户端标识符"""
|
||||
# 优先使用IP地址
|
||||
ip = request.headers.get("X-Forwarded-For", request.headers.get("X-Real-IP", ""))
|
||||
if not ip:
|
||||
ip = request.headers.get("Remote-Addr", "unknown")
|
||||
|
||||
# 如果有API Key,使用Key作为标识符(更精确)
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
return f"api_key:{auth_header[7:]}"
|
||||
|
||||
return f"ip:{ip}"
|
||||
|
||||
def get_endpoint_limiter(self, path: str) -> Optional[RateLimiter]:
|
||||
"""获取端点特定的限流器"""
|
||||
for endpoint, config in self.endpoint_limits.items():
|
||||
if path.startswith(endpoint):
|
||||
return RateLimiter(
|
||||
max_requests=config["max_requests"],
|
||||
time_window=config["time_window"]
|
||||
)
|
||||
return None
|
||||
|
||||
def create_rate_limit_response(self, retry_after: int = 60) -> Response:
|
||||
"""创建限流响应"""
|
||||
return Response(
|
||||
status=429,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Retry-After": str(retry_after),
|
||||
"X-Rate-Limit-Limit": str(self.limiter.max_requests),
|
||||
"X-Rate-Limit-Window": str(self.limiter.time_window),
|
||||
},
|
||||
body='{"error": "Rate limit exceeded", "message": "请稍后再试"}'
|
||||
)
|
||||
|
||||
def process(self, ctx: dict, next_fn) -> Optional[Response]:
|
||||
"""处理限流逻辑"""
|
||||
if not self.enabled:
|
||||
return next_fn()
|
||||
|
||||
request = ctx.get("request")
|
||||
if not request:
|
||||
return next_fn()
|
||||
|
||||
# 获取客户端标识符
|
||||
identifier = self.get_client_identifier(request)
|
||||
|
||||
# 获取端点特定的限流器
|
||||
endpoint_limiter = self.get_endpoint_limiter(request.path)
|
||||
limiter = endpoint_limiter or self.limiter
|
||||
|
||||
# 检查是否允许请求
|
||||
if not limiter.is_allowed(identifier):
|
||||
retry_after = self.limiter.time_window
|
||||
return self.create_rate_limit_response(retry_after)
|
||||
|
||||
return next_fn()
|
||||
return True
|
||||
@@ -1,4 +1,41 @@
|
||||
"""HTTP 路由 - 基于 oss/shared/router.py 的 BaseRouter"""
|
||||
import json
|
||||
from typing import Callable
|
||||
|
||||
from oss.shared.router import BaseRouter, BaseRoute, match_path, extract_path_params
|
||||
from .server import Request, Response
|
||||
|
||||
|
||||
class HttpRouter(BaseRouter):
|
||||
"""HTTP 路由"""
|
||||
|
||||
def add(self, method: str, path: str, handler: Callable):
|
||||
self.routes.append(BaseRoute(method, path, handler))
|
||||
|
||||
class HttpRouter:
|
||||
def handle(self, request: Request) -> Response:
|
||||
pass
|
||||
"""匹配路由并执行处理器"""
|
||||
for route in self.routes:
|
||||
if route.method == request.method and match_path(route.path, request.path):
|
||||
params = extract_path_params(route.path, request.path)
|
||||
try:
|
||||
result = route.handler(request, **params)
|
||||
if isinstance(result, Response):
|
||||
return result
|
||||
return Response(
|
||||
status=200,
|
||||
body=json.dumps(result) if not isinstance(result, str) else result,
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
except Exception as e:
|
||||
return Response(
|
||||
status=500,
|
||||
body=json.dumps({"error": "Internal Server Error", "message": str(e)}),
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
# 404 - 无匹配路由
|
||||
return Response(
|
||||
status=404,
|
||||
body=json.dumps({"error": "Not Found", "message": f"路由未找到: {request.method} {request.path}"}),
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
@@ -27,7 +27,7 @@ class HttpServer:
|
||||
|
||||
def __init__(self, router, middleware, host=None, port=None):
|
||||
config = get_config()
|
||||
self.host = host or config.get("HOST", "0.0.0.0")
|
||||
self.host = host or config.get("HOST", "127.0.0.1")
|
||||
self.port = port or config.get("HTTP_API_PORT", 8080)
|
||||
self.router = router
|
||||
self.middleware = middleware
|
||||
@@ -68,10 +68,18 @@ class HttpServer:
|
||||
|
||||
def do_OPTIONS(self):
|
||||
"""处理 CORS 预检请求"""
|
||||
self.send_response(200)
|
||||
self.send_header("Access-Control-Allow-Origin", "*")
|
||||
self.send_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
self.send_header("Access-Control-Allow-Headers", "Content-Type")
|
||||
config = get_config()
|
||||
allowed_origins = config.get("CORS_ALLOWED_ORIGINS", ["http://localhost:3000", "http://127.0.0.1:3000"])
|
||||
origin = self.headers.get("Origin", "")
|
||||
|
||||
if origin in allowed_origins or "*" in allowed_origins:
|
||||
self.send_response(200)
|
||||
self.send_header("Access-Control-Allow-Origin", origin if origin else "*")
|
||||
self.send_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
self.send_header("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
self.send_header("Access-Control-Allow-Credentials", "true")
|
||||
else:
|
||||
self.send_response(204)
|
||||
self.end_headers()
|
||||
|
||||
def _handle(self, method):
|
||||
|
||||
@@ -147,6 +147,10 @@ def use(plugin_name: str):
|
||||
_use_cache[plugin_name] = manager.plugins[plugin_name]
|
||||
return _use_cache[plugin_name]
|
||||
|
||||
# 插件未通过 plugin-loader 加载,记录警告
|
||||
from oss.logger.logger import Log
|
||||
Log.warn("plugin-bridge", f"use('{plugin_name}') 绕过 plugin-loader 直接加载,建议通过 plugin-loader 管理插件生命周期")
|
||||
|
||||
from oss.config import get_config
|
||||
config = get_config()
|
||||
store_dir = Path(config.get("store_dir", "store"))
|
||||
|
||||
@@ -429,13 +429,7 @@ class PluginManager:
|
||||
Log.error("plugin-loader", f"配置文件编码错误:{cf} - {e}")
|
||||
return {}
|
||||
|
||||
# 严格检查:不允许任何代码执行
|
||||
for p in ['import ', 'from ', 'open(', 'exec(', 'eval(', 'compile(', 'os.', 'sys.', 'subprocess', 'lambda', 'def ', 'class ']:
|
||||
if p in content:
|
||||
Log.warn("plugin-loader", f"{cf} 包含危险代码:{p}")
|
||||
return {}
|
||||
|
||||
# 尝试使用 ast.literal_eval 安全解析
|
||||
# 使用 ast.literal_eval 安全解析(只允许字面量,不会执行代码)
|
||||
try:
|
||||
result = ast.literal_eval(content)
|
||||
if isinstance(result, dict):
|
||||
@@ -476,13 +470,7 @@ class PluginManager:
|
||||
Log.error("plugin-loader", f"扩展文件读取失败:{e}")
|
||||
return {}
|
||||
|
||||
# 严格检查:不允许任何代码执行
|
||||
for p in ['import ', 'from ', 'open(', 'exec(', 'eval(', 'compile(', 'os.', 'sys.', 'subprocess', 'lambda', 'def ', 'class ']:
|
||||
if p in content:
|
||||
Log.warn("plugin-loader", f"{ef} 包含危险代码:{p}")
|
||||
return {}
|
||||
|
||||
# 尝试使用 ast.literal_eval 安全解析
|
||||
# 使用 ast.literal_eval 安全解析(只允许字面量,不会执行代码)
|
||||
try:
|
||||
result = ast.literal_eval(content)
|
||||
if isinstance(result, dict):
|
||||
@@ -611,29 +599,26 @@ class PluginManager:
|
||||
if not store_dir.exists(): return
|
||||
core_plugins = {"webui", "dashboard", "pkg-manager"}
|
||||
skip = {"plugin-loader"}
|
||||
first_plugins = []
|
||||
other_plugins = []
|
||||
plugin_dirs = []
|
||||
for ad in store_dir.iterdir():
|
||||
if ad.is_dir():
|
||||
for pd in ad.iterdir():
|
||||
if not pd.is_dir() or pd.name in skip or not (pd / "main.py").exists():
|
||||
continue
|
||||
# 读取 load_priority,默认为 100
|
||||
priority = 100
|
||||
manifest_file = pd / "manifest.json"
|
||||
is_first = False
|
||||
if manifest_file.exists():
|
||||
try:
|
||||
meta = json.loads(manifest_file.read_text()).get("metadata", {})
|
||||
if meta.get("load_priority") == "first":
|
||||
is_first = True
|
||||
except (json.JSONDecodeError, OSError):
|
||||
raw = meta.get("load_priority", 100)
|
||||
priority = 0 if raw == "first" else (int(raw) if isinstance(raw, (int, float)) else 100)
|
||||
except (json.JSONDecodeError, OSError, (ValueError, TypeError)):
|
||||
pass
|
||||
if is_first:
|
||||
first_plugins.append(pd)
|
||||
else:
|
||||
other_plugins.append(pd)
|
||||
for pd in first_plugins:
|
||||
self.load(pd, use_sandbox=pd.name not in core_plugins)
|
||||
for pd in other_plugins:
|
||||
plugin_dirs.append((priority, pd))
|
||||
# 按优先级升序排序(数值越小越先加载)
|
||||
plugin_dirs.sort(key=lambda x: x[0])
|
||||
for _, pd in plugin_dirs:
|
||||
self.load(pd, use_sandbox=pd.name not in core_plugins)
|
||||
self._link_capabilities()
|
||||
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class PluginStorage:
|
||||
def __init__(self, plugin_name: str, data_dir: str = None):
|
||||
config = get_config()
|
||||
@@ -10,39 +14,64 @@ class PluginStorage:
|
||||
|
||||
|
||||
def _load(self):
|
||||
"""从 data.json 加载持久化数据"""
|
||||
data_file = self.data_dir / "data.json"
|
||||
with open(data_file, "w", encoding="utf-8") as f:
|
||||
json.dump(self._data, f, ensure_ascii=False, indent=2)
|
||||
if data_file.exists():
|
||||
try:
|
||||
with open(data_file, "r", encoding="utf-8") as f:
|
||||
self._data = json.load(f)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
Log.error("plugin-storage", f"加载数据失败 {self.plugin_name}: {e}")
|
||||
self._data = {}
|
||||
else:
|
||||
self._data = {}
|
||||
|
||||
def _save(self):
|
||||
"""将数据持久化到 data.json"""
|
||||
data_file = self.data_dir / "data.json"
|
||||
try:
|
||||
with open(data_file, "w", encoding="utf-8") as f:
|
||||
json.dump(self._data, f, ensure_ascii=False, indent=2)
|
||||
except OSError as e:
|
||||
Log.error("plugin-storage", f"保存数据失败 {self.plugin_name}: {e}")
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
with self._lock:
|
||||
return self._data.get(key, default)
|
||||
|
||||
def set(self, key: str, value: Any):
|
||||
with self._lock:
|
||||
self._data[key] = value
|
||||
self._save()
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
with self._lock:
|
||||
return key in self._data
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
self._save()
|
||||
return True
|
||||
return False
|
||||
|
||||
def keys(self) -> list[str]:
|
||||
with self._lock:
|
||||
self._data.clear()
|
||||
self._save()
|
||||
return list(self._data.keys())
|
||||
|
||||
def size(self) -> int:
|
||||
with self._lock:
|
||||
return self._data.copy()
|
||||
return len(self._data)
|
||||
|
||||
def set_many(self, data: dict[str, Any]):
|
||||
return {
|
||||
"plugin": self.plugin_name,
|
||||
"keys": self.size(),
|
||||
"path": str(self.data_dir),
|
||||
}
|
||||
with self._lock:
|
||||
self._data.update(data)
|
||||
self._save()
|
||||
|
||||
|
||||
def read_file(self, path: str, mode: str = "r") -> Optional[str | bytes]:
|
||||
try:
|
||||
file_path = self._resolve_path(path)
|
||||
if file_path is None:
|
||||
Log.warn("plugin-storage", f"路径穿越被拒绝: {self.plugin_name}/{path}")
|
||||
return None
|
||||
if not file_path.exists() or not file_path.is_file():
|
||||
return None
|
||||
with open(file_path, mode, encoding="utf-8" if mode == "r" else None) as f:
|
||||
@@ -54,6 +83,9 @@ class PluginStorage:
|
||||
def write_file(self, path: str, content: str | bytes):
|
||||
try:
|
||||
file_path = self._resolve_path(path)
|
||||
if file_path is None:
|
||||
Log.warn("plugin-storage", f"路径穿越被拒绝: {self.plugin_name}/{path}")
|
||||
return
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if isinstance(content, bytes):
|
||||
with open(file_path, "wb") as f:
|
||||
@@ -67,6 +99,9 @@ class PluginStorage:
|
||||
def delete_file(self, path: str) -> bool:
|
||||
try:
|
||||
file_path = self._resolve_path(path)
|
||||
if file_path is None:
|
||||
Log.warn("plugin-storage", f"路径穿越被拒绝: {self.plugin_name}/{path}")
|
||||
return False
|
||||
if file_path.exists() and file_path.is_file():
|
||||
file_path.unlink()
|
||||
return True
|
||||
@@ -78,6 +113,9 @@ class PluginStorage:
|
||||
def list_files(self, prefix: str = "") -> list[str]:
|
||||
try:
|
||||
search_dir = self._resolve_path(prefix) if prefix else self.data_dir
|
||||
if search_dir is None:
|
||||
Log.warn("plugin-storage", f"路径穿越被拒绝: {self.plugin_name}/{prefix}")
|
||||
return []
|
||||
if not search_dir.exists():
|
||||
return []
|
||||
files = []
|
||||
@@ -91,15 +129,14 @@ class PluginStorage:
|
||||
|
||||
def file_exists(self, path: str) -> bool:
|
||||
file_path = self._resolve_path(path)
|
||||
if file_path is None:
|
||||
return False
|
||||
return file_path.exists() and file_path.is_file()
|
||||
|
||||
def serve_file(self, path: str):
|
||||
try:
|
||||
file_path = self._resolve_path(path)
|
||||
|
||||
try:
|
||||
file_path.resolve().relative_to(self.data_dir.resolve())
|
||||
except ValueError:
|
||||
if file_path is None:
|
||||
return Response(status=403, body="Forbidden: path traversal detected")
|
||||
|
||||
if not file_path.exists() or not file_path.is_file():
|
||||
@@ -128,8 +165,19 @@ class PluginStorage:
|
||||
except Exception as e:
|
||||
return Response(status=500, body=f"Error serving file: {e}")
|
||||
|
||||
def _resolve_path(self, path: str) -> Path:
|
||||
return self.data_dir.resolve()
|
||||
def _resolve_path(self, path: str) -> Optional[Path]:
|
||||
"""安全解析路径,防止路径穿越
|
||||
|
||||
将 path 拼接到 data_dir 下,resolve 后校验是否仍在 data_dir 范围内。
|
||||
如果 path 试图穿越到 data_dir 之外,返回 None。
|
||||
"""
|
||||
try:
|
||||
target = (self.data_dir / path).resolve()
|
||||
# 校验是否仍在 data_dir 范围内
|
||||
target.relative_to(self.data_dir.resolve())
|
||||
return target
|
||||
except (ValueError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
class SharedStorage:
|
||||
@@ -161,17 +209,19 @@ class PluginStoragePlugin(Plugin):
|
||||
self.config = {}
|
||||
self.data_root = Path("./data")
|
||||
|
||||
def start(self):
|
||||
Log.info("plugin-storage", f"插件存储服务已启动 (root={self.data_root})")
|
||||
|
||||
def stop(self):
|
||||
def init(self, deps: dict = None):
|
||||
"""初始化时加载配置并初始化共享存储"""
|
||||
config_path = Path("./data/plugin-storage/config.json")
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
self.config = json.load(f)
|
||||
self.data_root = Path(self.config.get("data_root", "./data"))
|
||||
shared_dir_name = self.config.get("shared_dir", "DCIM")
|
||||
shared_dir = self.data_root / shared_dir_name
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
self.config = json.load(f)
|
||||
self.data_root = Path(self.config.get("data_root", "./data"))
|
||||
shared_dir_name = self.config.get("shared_dir", "DCIM")
|
||||
shared_dir = self.data_root / shared_dir_name
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
Log.error("plugin-storage", f"加载配置失败: {e}")
|
||||
shared_dir = self.data_root / "DCIM"
|
||||
else:
|
||||
Log.warn("plugin-storage", "config.json 不存在,使用默认配置")
|
||||
self.config = {"data_root": "./data", "shared_dir": "DCIM"}
|
||||
@@ -180,6 +230,12 @@ class PluginStoragePlugin(Plugin):
|
||||
|
||||
self.shared = SharedStorage(self, shared_dir=shared_dir)
|
||||
|
||||
def start(self):
|
||||
Log.info("plugin-storage", f"插件存储服务已启动 (root={self.data_root})")
|
||||
|
||||
def stop(self):
|
||||
Log.info("plugin-storage", "插件存储服务已停止")
|
||||
|
||||
def get_storage(self, plugin_name: str) -> PluginStorage:
|
||||
if plugin_name not in self.storages:
|
||||
self.storages[plugin_name] = PluginStorage(plugin_name)
|
||||
|
||||
Reference in New Issue
Block a user