修复项目主要错误

This commit is contained in:
Falck
2026-05-04 21:19:34 +08:00
parent ba58b3939a
commit 4441a968db
20 changed files with 639 additions and 317 deletions

View File

@@ -42,4 +42,15 @@ class QualityCheck:
return issues
def _calculate_complexity(self, node: ast.AST) -> int:
pass
"""计算圈复杂度"""
complexity = 1
for child in ast.walk(node):
if isinstance(child, (ast.If, ast.While, ast.For, ast.AsyncFor)):
complexity += 1
elif isinstance(child, ast.ExceptHandler):
complexity += 1
elif isinstance(child, ast.BoolOp):
complexity += len(child.values) - 1
elif isinstance(child, (ast.And, ast.Or)):
complexity += 1
return complexity

View File

@@ -1,3 +1,8 @@
import ast
from pathlib import Path
from typing import Optional
class ReferenceCheck:
STD_MODULES = {
'os', 'sys', 'json', 're', 'time', 'datetime', 'pathlib',
@@ -41,30 +46,40 @@ class ReferenceCheck:
self._scan_project_modules()
def _scan_project_modules(self):
if dir_path.exists():
for item in dir_path.iterdir():
if item.is_file() and item.name.endswith(".py") and item.name != "__init__.py":
module_name = item.name[:-3]
full_name = f"{base_name}.{module_name}"
self._available_modules.add(full_name)
elif item.is_dir() and (item / "__init__.py").exists():
full_name = f"{base_name}.{item.name}"
self._available_modules.add(full_name)
self._scan_module_dir(item, full_name)
"""扫描项目目录下的所有 Python 模块"""
store_dir = self.project_root / "store"
if not store_dir.exists():
return
for author_dir in store_dir.iterdir():
if not author_dir.is_dir():
continue
for plugin_dir in author_dir.iterdir():
if not plugin_dir.is_dir():
continue
self._scan_plugin_modules(plugin_dir, plugin_dir.name)
def _scan_plugin_modules(self, plugin_dir: Path, base_name: str):
if dir_path.exists():
for item in dir_path.iterdir():
if item.is_file() and item.name.endswith(".py") and item.name != "__init__.py":
module_name = item.name[:-3]
self._available_modules.add(f"{base_name}.{module_name}")
elif item.is_dir() and (item / "__init__.py").exists():
self._add_module_from_dir(item, f"{base_name}.{item.name}")
"""扫描单个插件目录下的模块"""
if not plugin_dir.exists():
return
for item in plugin_dir.iterdir():
if item.is_file() and item.name.endswith(".py") and item.name != "__init__.py":
module_name = item.name[:-3]
self._available_modules.add(f"{base_name}.{module_name}")
elif item.is_dir() and (item / "__init__.py").exists():
self._available_modules.add(f"{base_name}.{item.name}")
def check(self, filepath: str, content: str) -> list:
issues = []
file_path = Path(filepath)
try:
tree = ast.parse(content)
except SyntaxError:
return []
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
@@ -103,25 +118,8 @@ class ReferenceCheck:
return issues
def _check_variable_references(self, filepath: str, tree: ast.AST, content: str) -> list:
issues = []
for node in ast.walk(tree):
if isinstance(node, ast.Attribute):
if isinstance(node.value, ast.Name):
var_name = node.value.id
if var_name in ('None', 'True', 'False'):
issues.append({
"file": filepath,
"line": node.lineno,
"severity": "critical",
"type": "attribute_error",
"message": f"尝试访问 {var_name} 的属性: {node.attr}"
})
return issues
def _check_function_calls(self, filepath: str, tree: ast.AST, content: str) -> list:
def _is_module_available(self, module_name: str, file_path: Optional[Path] = None) -> bool:
"""检查模块是否可用"""
if module_name in self._available_modules:
return True
@@ -154,5 +152,49 @@ class ReferenceCheck:
return False
def _check_variable_references(self, filepath: str, tree: ast.AST, content: str) -> list:
issues = []
for node in ast.walk(tree):
if isinstance(node, ast.Attribute):
if isinstance(node.value, ast.Name):
var_name = node.value.id
if var_name in ('None', 'True', 'False'):
issues.append({
"file": filepath,
"line": node.lineno,
"severity": "critical",
"type": "attribute_error",
"message": f"尝试访问 {var_name} 的属性: {node.attr}"
})
return issues
def _is_name_defined(self, name: str, tree: ast.AST, line: int) -> bool:
pass
"""检查变量名是否在 AST 中定义"""
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
if node.name == name:
return True
for arg in node.args.args:
if arg.arg == name:
return True
elif isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name) and target.id == name:
return True
elif isinstance(node, ast.AnnAssign):
if isinstance(node.target, ast.Name) and node.target.id == name:
return True
elif isinstance(node, ast.ClassDef):
if node.name == name:
return True
elif isinstance(node, ast.With):
for item in node.items:
if item.optional_vars and isinstance(item.optional_vars, ast.Name):
if item.optional_vars.id == name:
return True
elif isinstance(node, ast.ExceptHandler):
if node.name and node.name == name:
return True
return False

View File

@@ -1,34 +1,78 @@
class Reviewer:
import ast
import time
from pathlib import Path
from typing import Optional
from checks.security import SecurityCheck
from checks.quality import QualityCheck
from checks.style import StyleCheck
from checks.references import ReferenceCheck
from report.formatter import Formatter as ReportFormatter
class CodeReviewer:
def __init__(self, config: dict):
self.config = config
self.security = SecurityChecker()
self.quality = QualityChecker()
self.style = StyleChecker()
self.references = ReferenceChecker()
self.security = SecurityCheck()
self.quality = QualityCheck()
self.style = StyleCheck()
self.references = ReferenceCheck()
self.formatter = ReportFormatter(config.get("report_format", "console"))
def run_check(self, scan_dirs: list) -> dict:
issues = []
files_scanned = 0
start_time = time.time()
try:
with open(filepath, 'r', encoding='utf-8') as f:
content = f.read()
exclude_patterns = self.config.get("exclude_patterns", ["__pycache__"])
max_file_size = self.config.get("max_file_size", 102400)
issues.extend(self.security.check(filepath, content))
for scan_dir in scan_dirs:
scan_path = Path(scan_dir)
if not scan_path.exists() or not scan_path.is_dir():
continue
issues.extend(self.quality.check(filepath, content))
for py_file in scan_path.rglob("*.py"):
# 跳过排除目录
if any(part in exclude_patterns for part in py_file.parts):
continue
issues.extend(self.style.check(filepath, content))
filepath = str(py_file)
issues.extend(self.references.check(filepath, content))
# 检查文件大小
if py_file.stat().st_size > max_file_size:
issues.append({
"file": filepath,
"line": 0,
"severity": "warning",
"type": "file_too_large",
"message": f"文件过大 ({py_file.stat().st_size} 字节),跳过检查"
})
continue
except Exception as e:
issues.append({
"file": filepath,
"line": 0,
"severity": "error",
"type": "parse_error",
"message": f"文件解析失败: {e}"
})
try:
with open(filepath, 'r', encoding='utf-8') as f:
content = f.read()
return issues
files_scanned += 1
issues.extend(self.security.check(filepath, content))
issues.extend(self.quality.check(filepath, content))
issues.extend(self.style.check(filepath, content))
issues.extend(self.references.check(filepath, content))
except Exception as e:
issues.append({
"file": filepath,
"line": 0,
"severity": "error",
"type": "parse_error",
"message": f"文件解析失败: {e}"
})
scan_time = round(time.time() - start_time, 2)
return {
"files_scanned": files_scanned,
"total_issues": len(issues),
"issues": issues,
"scan_time": scan_time
}

View File

@@ -36,6 +36,7 @@ class CodeReviewerPlugin:
"report_format": config.get("report_format", "console")
}
from core.reviewer import CodeReviewer
self.reviewer = CodeReviewer(self.config)
Log.info("code-reviewer", "初始化完成")
@@ -46,4 +47,9 @@ class CodeReviewerPlugin:
Log.error("code-reviewer", "插件已停止")
def check(self, dirs: list = None) -> dict:
pass
if not self.reviewer:
return {"error": "code-reviewer 未初始化"}
scan_dirs = dirs if dirs else self.config.get("scan_dirs", ["store", "oss"])
result = self.reviewer.run_check(scan_dirs)
return result

View File

@@ -38,4 +38,6 @@ class Formatter:
return '\n'.join(lines)
def _format_json(self, result: dict) -> str:
pass
"""以 JSON 格式输出审查报告"""
import json
return json.dumps(result, ensure_ascii=False, indent=2)

View File

@@ -1,6 +1,7 @@
"""
CSRF 防护中间件
"""
import json
import hashlib
import secrets
import time
@@ -165,7 +166,6 @@ class CsrfMiddleware:
csrf_token = None
if request.headers.get("Content-Type") == "application/json":
try:
import json
body = json.loads(request.body)
csrf_token = body.get("csrf_token")
except:

View File

@@ -1,13 +1,27 @@
class HttpApiPlugin:
import json
from oss.plugin.types import Plugin, register_plugin_type
from .server import HttpServer, Response
from .router import HttpRouter
from .middleware import MiddlewareChain
class HttpApiPlugin(Plugin):
def __init__(self):
self.server = None
self.router = Router()
self.router = HttpRouter()
self.middleware = MiddlewareChain()
def init(self, deps: dict = None):
self.server = HttpServer(
router=self.router,
middleware=self.middleware,
)
self.server.start()
def stop(self):
if self.server:
self.server.stop()
return Response(
status=200,
body=json.dumps({"status": "ok", "service": "http-api"}),

View File

@@ -1,6 +1,8 @@
"""中间件链 - CORS/鉴权/日志/限流/CSRF/输入验证等"""
import json
import time
import threading
from collections import deque
from typing import Callable, Optional, Any
from oss.config import get_config
@@ -110,13 +112,7 @@ class RateLimitMiddleware(Middleware):
# 请求记录
self.requests = {}
self.lock = None # 延迟初始化
def _init_lock(self):
"""延迟初始化锁"""
if self.lock is None:
import threading
self.lock = threading.Lock()
self.lock = threading.Lock()
def _get_client_identifier(self, request: Request) -> str:
"""获取客户端标识符"""
@@ -152,21 +148,22 @@ class RateLimitMiddleware(Middleware):
max_requests = limit["max_requests"]
time_window = limit["time_window"]
# 清理过期的请求记录
if limit_key not in self.requests:
self.requests[limit_key] = []
request_times = self.requests[limit_key]
while request_times and request_times[0] <= now - time_window:
request_times.popleft()
# 检查是否超过限制
if len(request_times) >= max_requests:
return True
# 记录当前请求
request_times.append(now)
return False
with self.lock:
# 清理过期的请求记录
if limit_key not in self.requests:
self.requests[limit_key] = deque()
request_times = self.requests[limit_key]
while request_times and request_times[0] <= now - time_window:
request_times.popleft()
# 检查是否超过限制
if len(request_times) >= max_requests:
return True
# 记录当前请求
request_times.append(now)
return False
def _create_rate_limit_response(self) -> Response:
"""创建限流响应"""
@@ -191,7 +188,6 @@ class RateLimitMiddleware(Middleware):
return next_fn()
# 获取客户端标识符
self._init_lock()
identifier = self._get_client_identifier(request)
# 检查是否被限流

View File

@@ -1,14 +1,11 @@
"""
限流中间件 - 防止DoS攻击
限流工具 - 令牌桶限流器
"""
import time
import threading
from typing import Dict, Optional
from typing import Dict
from collections import defaultdict, deque
from oss.config import get_config
from store.NebulaShell.http_api.server import Response
class RateLimiter:
"""令牌桶限流器"""
@@ -35,88 +32,4 @@ class RateLimiter:
# 记录当前请求
request_times.append(now)
return True
class RateLimitMiddleware:
"""限流中间件"""
def __init__(self):
self.config = get_config()
self.limiter = RateLimiter(
max_requests=self.config.get("RATE_LIMIT_MAX_REQUESTS", 100),
time_window=self.config.get("RATE_LIMIT_TIME_WINDOW", 60)
)
self.enabled = self.config.get("RATE_LIMIT_ENABLED", True)
# 不同端点的限流配置
self.endpoint_limits = {
"/api/dashboard/stats": {
"max_requests": 10,
"time_window": 60
},
"/api/pkg-manager/search": {
"max_requests": 50,
"time_window": 60
}
}
def get_client_identifier(self, request) -> str:
"""获取客户端标识符"""
# 优先使用IP地址
ip = request.headers.get("X-Forwarded-For", request.headers.get("X-Real-IP", ""))
if not ip:
ip = request.headers.get("Remote-Addr", "unknown")
# 如果有API Key使用Key作为标识符更精确
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
return f"api_key:{auth_header[7:]}"
return f"ip:{ip}"
def get_endpoint_limiter(self, path: str) -> Optional[RateLimiter]:
"""获取端点特定的限流器"""
for endpoint, config in self.endpoint_limits.items():
if path.startswith(endpoint):
return RateLimiter(
max_requests=config["max_requests"],
time_window=config["time_window"]
)
return None
def create_rate_limit_response(self, retry_after: int = 60) -> Response:
"""创建限流响应"""
return Response(
status=429,
headers={
"Content-Type": "application/json",
"Retry-After": str(retry_after),
"X-Rate-Limit-Limit": str(self.limiter.max_requests),
"X-Rate-Limit-Window": str(self.limiter.time_window),
},
body='{"error": "Rate limit exceeded", "message": "请稍后再试"}'
)
def process(self, ctx: dict, next_fn) -> Optional[Response]:
"""处理限流逻辑"""
if not self.enabled:
return next_fn()
request = ctx.get("request")
if not request:
return next_fn()
# 获取客户端标识符
identifier = self.get_client_identifier(request)
# 获取端点特定的限流器
endpoint_limiter = self.get_endpoint_limiter(request.path)
limiter = endpoint_limiter or self.limiter
# 检查是否允许请求
if not limiter.is_allowed(identifier):
retry_after = self.limiter.time_window
return self.create_rate_limit_response(retry_after)
return next_fn()
return True

View File

@@ -1,4 +1,41 @@
"""HTTP 路由 - 基于 oss/shared/router.py 的 BaseRouter"""
import json
from typing import Callable
from oss.shared.router import BaseRouter, BaseRoute, match_path, extract_path_params
from .server import Request, Response
class HttpRouter(BaseRouter):
"""HTTP 路由"""
def add(self, method: str, path: str, handler: Callable):
self.routes.append(BaseRoute(method, path, handler))
class HttpRouter:
def handle(self, request: Request) -> Response:
pass
"""匹配路由并执行处理器"""
for route in self.routes:
if route.method == request.method and match_path(route.path, request.path):
params = extract_path_params(route.path, request.path)
try:
result = route.handler(request, **params)
if isinstance(result, Response):
return result
return Response(
status=200,
body=json.dumps(result) if not isinstance(result, str) else result,
headers={"Content-Type": "application/json"}
)
except Exception as e:
return Response(
status=500,
body=json.dumps({"error": "Internal Server Error", "message": str(e)}),
headers={"Content-Type": "application/json"}
)
# 404 - 无匹配路由
return Response(
status=404,
body=json.dumps({"error": "Not Found", "message": f"路由未找到: {request.method} {request.path}"}),
headers={"Content-Type": "application/json"}
)

View File

@@ -27,7 +27,7 @@ class HttpServer:
def __init__(self, router, middleware, host=None, port=None):
config = get_config()
self.host = host or config.get("HOST", "0.0.0.0")
self.host = host or config.get("HOST", "127.0.0.1")
self.port = port or config.get("HTTP_API_PORT", 8080)
self.router = router
self.middleware = middleware
@@ -68,10 +68,18 @@ class HttpServer:
def do_OPTIONS(self):
"""处理 CORS 预检请求"""
self.send_response(200)
self.send_header("Access-Control-Allow-Origin", "*")
self.send_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
self.send_header("Access-Control-Allow-Headers", "Content-Type")
config = get_config()
allowed_origins = config.get("CORS_ALLOWED_ORIGINS", ["http://localhost:3000", "http://127.0.0.1:3000"])
origin = self.headers.get("Origin", "")
if origin in allowed_origins or "*" in allowed_origins:
self.send_response(200)
self.send_header("Access-Control-Allow-Origin", origin if origin else "*")
self.send_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
self.send_header("Access-Control-Allow-Headers", "Content-Type, Authorization")
self.send_header("Access-Control-Allow-Credentials", "true")
else:
self.send_response(204)
self.end_headers()
def _handle(self, method):

View File

@@ -147,6 +147,10 @@ def use(plugin_name: str):
_use_cache[plugin_name] = manager.plugins[plugin_name]
return _use_cache[plugin_name]
# 插件未通过 plugin-loader 加载,记录警告
from oss.logger.logger import Log
Log.warn("plugin-bridge", f"use('{plugin_name}') 绕过 plugin-loader 直接加载,建议通过 plugin-loader 管理插件生命周期")
from oss.config import get_config
config = get_config()
store_dir = Path(config.get("store_dir", "store"))

View File

@@ -429,13 +429,7 @@ class PluginManager:
Log.error("plugin-loader", f"配置文件编码错误:{cf} - {e}")
return {}
# 严格检查:不允许任何代码执行
for p in ['import ', 'from ', 'open(', 'exec(', 'eval(', 'compile(', 'os.', 'sys.', 'subprocess', 'lambda', 'def ', 'class ']:
if p in content:
Log.warn("plugin-loader", f"{cf} 包含危险代码:{p}")
return {}
# 尝试使用 ast.literal_eval 安全解析
# 使用 ast.literal_eval 安全解析(只允许字面量,不会执行代码)
try:
result = ast.literal_eval(content)
if isinstance(result, dict):
@@ -476,13 +470,7 @@ class PluginManager:
Log.error("plugin-loader", f"扩展文件读取失败:{e}")
return {}
# 严格检查:不允许任何代码执行
for p in ['import ', 'from ', 'open(', 'exec(', 'eval(', 'compile(', 'os.', 'sys.', 'subprocess', 'lambda', 'def ', 'class ']:
if p in content:
Log.warn("plugin-loader", f"{ef} 包含危险代码:{p}")
return {}
# 尝试使用 ast.literal_eval 安全解析
# 使用 ast.literal_eval 安全解析(只允许字面量,不会执行代码)
try:
result = ast.literal_eval(content)
if isinstance(result, dict):
@@ -611,29 +599,26 @@ class PluginManager:
if not store_dir.exists(): return
core_plugins = {"webui", "dashboard", "pkg-manager"}
skip = {"plugin-loader"}
first_plugins = []
other_plugins = []
plugin_dirs = []
for ad in store_dir.iterdir():
if ad.is_dir():
for pd in ad.iterdir():
if not pd.is_dir() or pd.name in skip or not (pd / "main.py").exists():
continue
# 读取 load_priority默认为 100
priority = 100
manifest_file = pd / "manifest.json"
is_first = False
if manifest_file.exists():
try:
meta = json.loads(manifest_file.read_text()).get("metadata", {})
if meta.get("load_priority") == "first":
is_first = True
except (json.JSONDecodeError, OSError):
raw = meta.get("load_priority", 100)
priority = 0 if raw == "first" else (int(raw) if isinstance(raw, (int, float)) else 100)
except (json.JSONDecodeError, OSError, (ValueError, TypeError)):
pass
if is_first:
first_plugins.append(pd)
else:
other_plugins.append(pd)
for pd in first_plugins:
self.load(pd, use_sandbox=pd.name not in core_plugins)
for pd in other_plugins:
plugin_dirs.append((priority, pd))
# 按优先级升序排序(数值越小越先加载)
plugin_dirs.sort(key=lambda x: x[0])
for _, pd in plugin_dirs:
self.load(pd, use_sandbox=pd.name not in core_plugins)
self._link_capabilities()

View File

@@ -1,3 +1,7 @@
from typing import Optional
from pathlib import Path
class PluginStorage:
def __init__(self, plugin_name: str, data_dir: str = None):
config = get_config()
@@ -10,39 +14,64 @@ class PluginStorage:
def _load(self):
"""从 data.json 加载持久化数据"""
data_file = self.data_dir / "data.json"
with open(data_file, "w", encoding="utf-8") as f:
json.dump(self._data, f, ensure_ascii=False, indent=2)
if data_file.exists():
try:
with open(data_file, "r", encoding="utf-8") as f:
self._data = json.load(f)
except (json.JSONDecodeError, OSError) as e:
Log.error("plugin-storage", f"加载数据失败 {self.plugin_name}: {e}")
self._data = {}
else:
self._data = {}
def _save(self):
"""将数据持久化到 data.json"""
data_file = self.data_dir / "data.json"
try:
with open(data_file, "w", encoding="utf-8") as f:
json.dump(self._data, f, ensure_ascii=False, indent=2)
except OSError as e:
Log.error("plugin-storage", f"保存数据失败 {self.plugin_name}: {e}")
def get(self, key: str, default: Any = None) -> Any:
with self._lock:
return self._data.get(key, default)
def set(self, key: str, value: Any):
with self._lock:
self._data[key] = value
self._save()
def delete(self, key: str) -> bool:
with self._lock:
return key in self._data
if key in self._data:
del self._data[key]
self._save()
return True
return False
def keys(self) -> list[str]:
with self._lock:
self._data.clear()
self._save()
return list(self._data.keys())
def size(self) -> int:
with self._lock:
return self._data.copy()
return len(self._data)
def set_many(self, data: dict[str, Any]):
return {
"plugin": self.plugin_name,
"keys": self.size(),
"path": str(self.data_dir),
}
with self._lock:
self._data.update(data)
self._save()
def read_file(self, path: str, mode: str = "r") -> Optional[str | bytes]:
try:
file_path = self._resolve_path(path)
if file_path is None:
Log.warn("plugin-storage", f"路径穿越被拒绝: {self.plugin_name}/{path}")
return None
if not file_path.exists() or not file_path.is_file():
return None
with open(file_path, mode, encoding="utf-8" if mode == "r" else None) as f:
@@ -54,6 +83,9 @@ class PluginStorage:
def write_file(self, path: str, content: str | bytes):
try:
file_path = self._resolve_path(path)
if file_path is None:
Log.warn("plugin-storage", f"路径穿越被拒绝: {self.plugin_name}/{path}")
return
file_path.parent.mkdir(parents=True, exist_ok=True)
if isinstance(content, bytes):
with open(file_path, "wb") as f:
@@ -67,6 +99,9 @@ class PluginStorage:
def delete_file(self, path: str) -> bool:
try:
file_path = self._resolve_path(path)
if file_path is None:
Log.warn("plugin-storage", f"路径穿越被拒绝: {self.plugin_name}/{path}")
return False
if file_path.exists() and file_path.is_file():
file_path.unlink()
return True
@@ -78,6 +113,9 @@ class PluginStorage:
def list_files(self, prefix: str = "") -> list[str]:
try:
search_dir = self._resolve_path(prefix) if prefix else self.data_dir
if search_dir is None:
Log.warn("plugin-storage", f"路径穿越被拒绝: {self.plugin_name}/{prefix}")
return []
if not search_dir.exists():
return []
files = []
@@ -91,15 +129,14 @@ class PluginStorage:
def file_exists(self, path: str) -> bool:
file_path = self._resolve_path(path)
if file_path is None:
return False
return file_path.exists() and file_path.is_file()
def serve_file(self, path: str):
try:
file_path = self._resolve_path(path)
try:
file_path.resolve().relative_to(self.data_dir.resolve())
except ValueError:
if file_path is None:
return Response(status=403, body="Forbidden: path traversal detected")
if not file_path.exists() or not file_path.is_file():
@@ -128,8 +165,19 @@ class PluginStorage:
except Exception as e:
return Response(status=500, body=f"Error serving file: {e}")
def _resolve_path(self, path: str) -> Path:
return self.data_dir.resolve()
def _resolve_path(self, path: str) -> Optional[Path]:
"""安全解析路径,防止路径穿越
将 path 拼接到 data_dir 下resolve 后校验是否仍在 data_dir 范围内。
如果 path 试图穿越到 data_dir 之外,返回 None。
"""
try:
target = (self.data_dir / path).resolve()
# 校验是否仍在 data_dir 范围内
target.relative_to(self.data_dir.resolve())
return target
except (ValueError, OSError):
return None
class SharedStorage:
@@ -161,17 +209,19 @@ class PluginStoragePlugin(Plugin):
self.config = {}
self.data_root = Path("./data")
def start(self):
Log.info("plugin-storage", f"插件存储服务已启动 (root={self.data_root})")
def stop(self):
def init(self, deps: dict = None):
"""初始化时加载配置并初始化共享存储"""
config_path = Path("./data/plugin-storage/config.json")
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f:
self.config = json.load(f)
self.data_root = Path(self.config.get("data_root", "./data"))
shared_dir_name = self.config.get("shared_dir", "DCIM")
shared_dir = self.data_root / shared_dir_name
try:
with open(config_path, "r", encoding="utf-8") as f:
self.config = json.load(f)
self.data_root = Path(self.config.get("data_root", "./data"))
shared_dir_name = self.config.get("shared_dir", "DCIM")
shared_dir = self.data_root / shared_dir_name
except (json.JSONDecodeError, OSError) as e:
Log.error("plugin-storage", f"加载配置失败: {e}")
shared_dir = self.data_root / "DCIM"
else:
Log.warn("plugin-storage", "config.json 不存在,使用默认配置")
self.config = {"data_root": "./data", "shared_dir": "DCIM"}
@@ -180,6 +230,12 @@ class PluginStoragePlugin(Plugin):
self.shared = SharedStorage(self, shared_dir=shared_dir)
def start(self):
Log.info("plugin-storage", f"插件存储服务已启动 (root={self.data_root})")
def stop(self):
Log.info("plugin-storage", "插件存储服务已停止")
def get_storage(self, plugin_name: str) -> PluginStorage:
if plugin_name not in self.storages:
self.storages[plugin_name] = PluginStorage(plugin_name)