ComfyUI-Ranking-API / router_messages.py
ZHIWEI666's picture
Upload 5 files
a9092cc verified
# router_messages.py
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
import time
import uuid
import subprocess
import os
import 数据库连接 as db
from notifications import add_notification
from models import PrivateMessage
from 安全认证 import require_auth, is_admin
router = APIRouter()
# ==========================================
# 新增:系统公告请求体模型
# ==========================================
class SystemAnnouncement(BaseModel):
admin_account: str
content: str
# ==========================================
# 新增:发布系统公告接口 (仅限管理员,使用JWT验证)
# ==========================================
@router.post("/api/system/announcement")
async def publish_announcement(ann: SystemAnnouncement, current_user: str = Depends(require_auth)):
# 🔒 P0安全修复:使用环境变量配置的管理员列表
if not is_admin(current_user):
raise HTTPException(status_code=403, detail="无权发布系统公告,仅管理员可操作")
# 查询管理员信息
users_db = db.load_data("users.json", default_data={})
admin_info = users_db.get(current_user, {})
announcements_db = db.load_data("announcements.json", default_data=[])
new_ann = {
"id": f"sys_{int(time.time())}_{uuid.uuid4().hex[:6]}",
"type": "system",
"from_user": current_user, # 使用真实的管理员账号
"from_name": admin_info.get("name", current_user), # 使用真实昵称,fallback 为账号
"from_avatar": admin_info.get("avatarDataUrl", ""), # 使用真实头像
"content": ann.content,
"created_at": int(time.time())
}
announcements_db.append(new_ann)
db.save_data("announcements.json", announcements_db)
return {"status": "success"}
# ==========================================
# 管理员调试:执行 Python 脚本
# ==========================================
class AdminScriptRequest(BaseModel):
admin_account: str
script_name: str
# 🔒 P0安全修复:脚本白名单(仅允许执行指定的脚本)
# 警告:添加新脚本前请确保其安全性
ALLOWED_SCRIPTS = {
"密码迁移.py", # 用户密码哈希化迁移
"测试脚本.py", # 接口测试工具
"迁移_余额合并.py", # 一次性余额合并迁移(执行后可移除)
}
@router.post("/api/admin/run-script")
async def run_admin_script(req: AdminScriptRequest, current_user: str = Depends(require_auth)):
"""
管理员专属:执行指定的 Python 脚本
🔒 P0安全修复:白名单 + 路径穿越防护
"""
# 🔒 P0安全修复:使用环境变量配置的管理员列表
if not is_admin(current_user):
raise HTTPException(status_code=403, detail="无权执行此操作,仅管理员可操作")
script_name = req.script_name.strip()
if not script_name:
raise HTTPException(status_code=400, detail="脚本名称不能为空")
# 🔒 P0安全修复:路径穿越攻击防护
if ".." in script_name or "/" in script_name or "\\" in script_name:
raise HTTPException(status_code=400, detail="🚨 安全拦截:脚本名称包含非法字符")
# 🔒 P0安全修复:白名单检查
if script_name not in ALLOWED_SCRIPTS:
raise HTTPException(
status_code=403,
detail=f"🚨 安全拦截:脚本 [{script_name}] 不在白名单中。允许的脚本: {list(ALLOWED_SCRIPTS)}"
)
# 获取当前工作目录
current_dir = os.path.dirname(os.path.abspath(__file__))
script_path = os.path.join(current_dir, script_name)
# 检查文件是否存在
if not os.path.exists(script_path):
return {
"status": "error",
"output": f"❌ 脚本文件不存在: {script_name}\n\n白名单脚本: {list(ALLOWED_SCRIPTS)}"
}
try:
# 执行脚本,设置超时 60 秒
result = subprocess.run(
["python", script_path],
capture_output=True,
text=True,
timeout=60,
cwd=current_dir,
encoding="utf-8"
)
output = ""
if result.stdout:
output += f"📝 标准输出:\n{result.stdout}\n"
if result.stderr:
output += f"\n⚠️ 错误输出:\n{result.stderr}"
if not output:
output = "✅ 脚本执行完成,无输出"
return {
"status": "success" if result.returncode == 0 else "error",
"return_code": result.returncode,
"output": output
}
except subprocess.TimeoutExpired:
return {
"status": "error",
"output": "❌ 脚本执行超时 (60秒)"
}
except Exception as e:
return {
"status": "error",
"output": f"❌ 执行异常: {str(e)}"
}
# ==========================================
# 原有功能:私信与聊天
# ==========================================
@router.post("/api/messages/private")
async def send_private_message(msg: PrivateMessage):
chats_db = db.load_data("chats.json", default_data={})
conv_id = f"{min(msg.sender, msg.receiver)}_{max(msg.sender, msg.receiver)}"
if conv_id not in chats_db: chats_db[conv_id] = []
chat_msg = {"id": f"chat_{int(time.time())}_{uuid.uuid4().hex[:6]}", "sender": msg.sender, "receiver": msg.receiver, "content": msg.content, "created_at": int(time.time()), "is_read": False}
chats_db[conv_id].append(chat_msg)
db.save_data("chats.json", chats_db)
add_notification(msg.receiver, {"type": "private", "from_user": msg.sender, "content": msg.content})
return {"status": "success"}
@router.get("/api/chats/{account}")
async def get_chat_list(account: str):
chats_db = db.load_data("chats.json", default_data={})
users_db = db.load_data("users.json", default_data={})
chat_list = []
for conv_id, msgs in chats_db.items():
if account in conv_id:
other_account = conv_id.replace(account, "").replace("_", "")
if not msgs: continue
last_msg = msgs[-1]
unread_count = sum(1 for m in msgs if m["receiver"] == account and not m.get("is_read"))
other_user = users_db.get(other_account, {})
chat_list.append({
"account": other_account,
"name": other_user.get("name", other_account),
"avatar": other_user.get("avatarDataUrl", ""),
"last_message": last_msg["content"],
"last_time": last_msg["created_at"],
"unread_count": unread_count
})
chat_list.sort(key=lambda x: x["last_time"], reverse=True)
return {"status": "success", "data": chat_list}
@router.get("/api/chats/{account}/{target_account}")
async def get_chat_history(account: str, target_account: str):
chats_db = db.load_data("chats.json", default_data={})
conv_id = f"{min(account, target_account)}_{max(account, target_account)}"
msgs = chats_db.get(conv_id, [])
now = int(time.time())
seven_days = 7 * 24 * 3600
valid_msgs = []
modified = False
for m in msgs:
if not m.get("is_read") or (now - m.get("created_at", 0) < seven_days):
valid_msgs.append(m)
else:
modified = True
# 本次访问即为已读
if m["receiver"] == account and not m.get("is_read"):
m["is_read"] = True
modified = True
if modified or len(valid_msgs) != len(msgs):
chats_db[conv_id] = valid_msgs
db.save_data("chats.json", chats_db)
return {"status": "success", "data": valid_msgs}
# ==========================================
# 改造:获取通知列表 (加入系统公告懒加载注入)
# 使用 atomic_update 避免并发覆盖问题
# 🔥 性能优化:先用只读方式检查是否有实际变更,避免无意义的写入和HF上传
# ==========================================
@router.get("/api/messages/{account}")
async def get_messages(account: str, count_only: bool = False, current_user: str = Depends(require_auth)):
# 🔥 count_only 模式:轻量级轮询,只返回未读数,不标记已读
if count_only:
messages_db = db.load_data("messages.json", default_data={})
user_msgs = messages_db.get(account, [])
now = int(time.time())
seven_days = 7 * 24 * 3600
# 只统计未过期的未读消息
unread = sum(1 for m in user_msgs
if not m.get("is_read")
and (now - m.get("created_at", 0) < seven_days))
return {"status": "success", "unread_count": unread}
# 公告是只读的,先加载
announcements_db = db.load_data("announcements.json", default_data=[])
# 🔥 性能优化:先用只读方式检查是否有实际变更需要写入
messages_db = db.load_data("messages.json", default_data={})
user_msgs = messages_db.get(account, [])
now = int(time.time())
seven_days = 7 * 24 * 3600
# 检查三个条件判断是否需要写入
# 1. 是否有新公告需要注入
user_msg_ids = {m.get("id") for m in user_msgs}
has_new_announcements = any(
ann.get("id") not in user_msg_ids
for ann in announcements_db
)
# 2. 是否有未读消息需要标记已读
has_unread = any(not m.get("is_read") for m in user_msgs)
# 3. 是否有已读超过7天的消息需要清理
has_expired = any(
m.get("is_read") and (now - m.get("created_at", 0) >= seven_days)
for m in user_msgs
)
needs_update = has_new_announcements or has_unread or has_expired
# 🔥 修复:在标记已读之前先计算真实的未读数
unread_before_mark = sum(1 for m in user_msgs
if not m.get("is_read")
and (now - m.get("created_at", 0) < seven_days))
if not needs_update:
# 无变更,直接返回只读数据,不触发写入和HF上传
return {
"status": "success",
"data": user_msgs,
"unread_count": unread_before_mark # 🔥 修复:返回真实的未读数
}
# 有变更需要写入,使用 atomic_update 保证并发安全
result_container = [None]
def updater(data):
user_msgs = data.get(account, [])
# --- 核心改造区:瞬间比对并注入全局公告 ---
user_msg_ids = {m.get("id") for m in user_msgs}
injected = False
for ann in announcements_db:
if ann.get("id") not in user_msg_ids:
new_sys_msg = dict(ann)
new_sys_msg["is_read"] = False
new_sys_msg["receiver"] = account
user_msgs.append(new_sys_msg)
injected = True
if injected:
# 重新按照时间倒序排列,让新公告置顶
user_msgs.sort(key=lambda x: x.get("created_at", 0), reverse=True)
# ----------------------------------------
now = int(time.time())
seven_days = 7 * 24 * 3600
valid = []
# 如果注入了新公告,则判定为需要回写数据库保存
modified = injected
for m in user_msgs:
if not m.get("is_read") or (now - m.get("created_at", 0) < seven_days):
valid.append(m)
else:
modified = True
# 本次访问即为已读 - 将所有未读消息标记为已读
if not m.get("is_read"):
m["is_read"] = True
modified = True
# 原地修改 data,atomic_update 会自动保存
data[account] = valid
# 通过闭包返回结果
result_container[0] = {
"status": "success",
"data": valid,
"unread_count": unread_before_mark # 🔥 修复:返回标记已读前的真实未读数
}
db.atomic_update("messages.json", updater, default_data={})
return result_container[0]
@router.post("/api/messages/{account}/read")
async def mark_messages_read(account: str):
"""
标记消息为已读(原子操作,并发安全)
"""
def updater(data):
user_msgs = data.get(account, [])
modified = False
for m in user_msgs:
if not m.get("is_read"):
m["is_read"] = True
modified = True
# 原地修改 data,atomic_update 会自动保存
if modified:
data[account] = user_msgs
db.atomic_update("messages.json", updater, default_data={})
return {"status": "success"}