Spaces:
Sleeping
Sleeping
File size: 8,913 Bytes
e7b9fb6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 |
"""
工具函数
"""
import logging
import os
import shutil
import tempfile
import time
from pathlib import Path
from typing import Optional, Dict, Any, List, Tuple

import numpy as np
import trimesh
logger = logging.getLogger(__name__)
def validate_file(file_path: str, max_size_mb: int = 50) -> Tuple[bool, str]:
    """
    Check whether an uploaded file can be used as a 3D model.

    The checks run in this order: the path exists, the file is no larger
    than ``max_size_mb``, the extension is one of .obj/.glb/.ply/.stl, and
    ``trimesh.load(..., force='mesh')`` yields an object with at least one
    vertex.

    Args:
        file_path: Path of the file to validate.
        max_size_mb: Largest accepted size, in units of 1024 * 1024 bytes.

    Returns:
        An ``(is_valid, message)`` tuple. The message describes the first
        failed check, or confirms validity. Exceptions are converted into a
        failure tuple rather than raised.
    """
    bytes_per_mb = 1024 * 1024
    accepted_suffixes = ('.obj', '.glb', '.ply', '.stl')

    try:
        if not os.path.exists(file_path):
            return False, "文件不存在"

        # Size limit
        size_in_mb = os.path.getsize(file_path) / bytes_per_mb
        if size_in_mb > max_size_mb:
            return False, f"文件太大: {size_in_mb:.1f}MB > {max_size_mb}MB"

        # Extension whitelist (case-insensitive)
        suffix = Path(file_path).suffix.lower()
        if suffix not in accepted_suffixes:
            return False, f"不支持的文件格式: {suffix}"

        # Trial load: anything raised here is reported as a format error
        try:
            loaded = trimesh.load(file_path, force='mesh')
            has_geometry = hasattr(loaded, 'vertices') and len(loaded.vertices) != 0
            if not has_geometry:
                return False, "文件无法解析为有效的3D模型"
        except Exception as load_error:
            return False, f"文件格式错误: {load_error}"

        return True, "文件有效"
    except Exception as error:
        return False, f"文件验证失败: {error}"
def get_model_info(file_path: str) -> Dict[str, Any]:
    """
    Load a model with trimesh and summarize its geometry.

    Args:
        file_path: Path of the model file.

    Returns:
        On success, a dict with ``file_name``, ``file_size_mb``,
        ``vertex_count``, ``face_count``, ``bounding_box`` (``min``, ``max``,
        ``size`` and ``center`` as lists), ``surface_area``, ``volume``,
        ``is_watertight`` and ``is_closed``. If anything raises, the error is
        logged and a dict holding only ``file_name`` and ``error`` is returned.
    """
    try:
        mesh = trimesh.load(file_path, force='mesh')

        # Element counts; a loaded object lacking the attribute counts as empty
        if hasattr(mesh, 'vertices'):
            n_vertices = len(mesh.vertices)
        else:
            n_vertices = 0
        if hasattr(mesh, 'faces'):
            n_faces = len(mesh.faces)
        else:
            n_faces = 0

        # Axis-aligned bounding box; all zeros when there are no vertices
        if n_vertices > 0:
            bounds = mesh.bounds
            lower, upper = bounds[0], bounds[1]
            bounding_box = {
                'min': lower.tolist(),
                'max': upper.tolist(),
                'size': (upper - lower).tolist(),
                'center': ((lower + upper) / 2).tolist()
            }
        else:
            bounding_box = {
                'min': [0, 0, 0],
                'max': [0, 0, 0],
                'size': [0, 0, 0],
                'center': [0, 0, 0]
            }

        # Surface area and volume, defaulting to 0 when unavailable
        area = getattr(mesh, 'area', 0)
        enclosed_volume = getattr(mesh, 'volume', 0)

        info = {
            'file_name': os.path.basename(file_path),
            'file_size_mb': os.path.getsize(file_path) / (1024 * 1024),
            'vertex_count': n_vertices,
            'face_count': n_faces,
            'bounding_box': bounding_box,
            'surface_area': float(area),
            'volume': float(enclosed_volume),
            'is_watertight': getattr(mesh, 'is_watertight', False),
            'is_closed': getattr(mesh, 'is_closed', False)
        }
        return info
    except Exception as error:
        logger.error(f"Failed to get model info: {error}")
        return {
            'file_name': os.path.basename(file_path),
            'error': str(error)
        }
def cleanup_temp_files(temp_dir: str, keep_files: Optional[List[str]] = None):
    """
    Delete the contents of a temporary directory on a best-effort basis.

    Every direct child of ``temp_dir`` is removed except entries whose name
    appears in ``keep_files``. Symbolic links are unlinked instead of being
    followed, so broken links and links to directories are cleaned up as
    well, and whatever a link points at is left alone. Failures are logged
    and never raised.

    Args:
        temp_dir: Directory whose entries should be removed. Nothing happens
            if it does not exist.
        keep_files: Entry names (not full paths) to leave in place.
    """
    try:
        if not os.path.exists(temp_dir):
            return

        for entry_name in os.listdir(temp_dir):
            if keep_files and entry_name in keep_files:
                continue

            entry_path = os.path.join(temp_dir, entry_name)
            try:
                # Check for a link first: isfile()/isdir() follow links, so a
                # broken link matches neither, and rmtree() refuses to run on
                # a link that points at a directory.
                if os.path.islink(entry_path) or os.path.isfile(entry_path):
                    os.remove(entry_path)
                elif os.path.isdir(entry_path):
                    shutil.rmtree(entry_path)
            except Exception as e:
                # One stubborn entry must not stop the rest of the cleanup.
                logger.warning(f"Failed to remove {entry_path}: {str(e)}")
    except Exception as e:
        logger.error(f"Cleanup failed: {str(e)}")
def format_processing_time(seconds: float) -> str:
    """
    Render a duration as a short string with one decimal place.

    Durations under a minute are shown in seconds, durations under an hour
    in minutes, and everything else in hours.

    Args:
        seconds: Duration in seconds.

    Returns:
        The formatted duration followed by its unit.
    """
    seconds_per_minute = 60
    seconds_per_hour = 3600

    if seconds < seconds_per_minute:
        return f"{seconds:.1f}秒"
    if seconds < seconds_per_hour:
        return f"{seconds / seconds_per_minute:.1f}分钟"
    return f"{seconds / seconds_per_hour:.1f}小时"
def get_prompt_suggestions(model_info: Dict[str, Any]) -> List[str]:
    """
    Suggest rigging prompts for a model based on its name and size.

    The lower-cased ``file_name`` is scanned for category keywords; the
    first matching category supplies three prompts, and a generic set is used
    when nothing matches. One more prompt is appended for models with more
    than 10000 or fewer than 1000 vertices.

    Args:
        model_info: Model info dict; ``file_name`` (default ``''``) and
            ``vertex_count`` (default ``0``) are read.

    Returns:
        At most five prompt strings.
    """
    # Ordered (keywords, prompts) rules: the first rule with a keyword that
    # occurs in the file name wins.
    keyword_rules = (
        (
            ('human', 'person', 'character', 'boy', 'girl'),
            (
                "realistic human skeleton for walking animations",
                "character with full body rig for game animation",
                "human bone structure suitable for motion capture",
            ),
        ),
        (
            ('dog', 'cat', 'animal', 'pet'),
            (
                "four-legged animal with spine and tail bones",
                "quadruped skeleton for natural movement",
                "animal bone structure with flexible spine",
            ),
        ),
        (
            ('bird', 'eagle', 'chicken'),
            (
                "bird skeleton with wing bones for flight",
                "avian bone structure with hollow bones",
                "bird with articulated wings and tail",
            ),
        ),
        (
            ('robot', 'mech', 'mechanical'),
            (
                "mechanical robot with joint articulation",
                "industrial robot with precise joint control",
                "mech suit with hydraulic joint system",
            ),
        ),
    )
    generic_prompts = (
        "articulated skeleton suitable for animation",
        "flexible bone structure for general movement",
        "skeleton with natural joint hierarchy",
    )

    # Name-based prompts
    lowered_name = model_info.get('file_name', '').lower()
    matched_prompts = next(
        (
            rule_prompts
            for keywords, rule_prompts in keyword_rules
            if any(word in lowered_name for word in keywords)
        ),
        generic_prompts,
    )
    prompts = list(matched_prompts)

    # Complexity-based prompt
    n_vertices = model_info.get('vertex_count', 0)
    if n_vertices > 10000:
        prompts.append("detailed skeleton for high-poly model")
    elif n_vertices < 1000:
        prompts.append("simple skeleton for low-poly model")

    return prompts[:5]  # cap the number of suggestions
def create_processing_status(stage: str, progress: float, message: str) -> Dict[str, Any]:
    """
    Build a status record for one step of processing.

    Args:
        stage: Name of the processing stage.
        progress: Completion fraction; values outside 0-1 are clamped.
        message: Status text.

    Returns:
        A dict with ``stage``, ``progress`` (clamped to the range 0.0-1.0),
        ``message`` and ``timestamp`` (the value of ``time.time()`` when the
        record was created).
    """
    clamped_progress = min(max(progress, 0.0), 1.0)
    return {
        'stage': stage,
        'progress': clamped_progress,
        'message': message,
        'timestamp': time.time(),
    }
def estimate_processing_time(model_info: Dict[str, Any]) -> float:
    """
    Estimate how long processing a model will take.

    The estimate starts at 30 seconds and grows by 50% of that for every
    10000 vertices-plus-faces, capped at 120 seconds.

    Args:
        model_info: Model info dict; ``vertex_count`` and ``face_count`` are
            read, each defaulting to 1000 when absent.

    Returns:
        The estimate in seconds, always as a float and never above 120.0.
        60.0 is returned when the estimate cannot be computed (for example
        when a count is not numeric).
    """
    base_time = 30.0       # seconds for a trivial model
    max_time = 120.0       # upper bound on the estimate
    fallback_time = 60.0   # used when the inputs are unusable

    try:
        vertex_count = model_info.get('vertex_count', 1000)
        face_count = model_info.get('face_count', 1000)

        # Simple complexity-based scaling
        complexity_factor = (vertex_count + face_count) / 10000
        estimated_time = base_time * (1 + complexity_factor * 0.5)

        # float() keeps the return type consistent with the annotation even
        # when the inputs are NumPy scalars.
        return float(min(estimated_time, max_time))
    except Exception:
        # Deliberately broad: this is only an estimate, so any malformed
        # input yields the fallback instead of an error.
        return fallback_time
def generate_download_filename(original_name: str, suffix: str) -> str:
    """
    Derive a download file name from an original file name.

    The last extension of ``original_name`` is dropped and ``suffix`` is
    appended after an underscore.

    Args:
        original_name: Original file name.
        suffix: Text appended after the underscore.

    Returns:
        The new file name.
    """
    stem, _extension = os.path.splitext(original_name)
    return f"{stem}_{suffix}"
def safe_json_serialize(obj: Any) -> Any:
    """
    Recursively replace NumPy values with plain Python equivalents.

    Arrays become nested lists, and NumPy booleans, floats and integers
    become ``bool``, ``float`` and ``int``. Dicts, lists and tuples are
    rebuilt with their contents converted; NumPy scalar dict keys are
    converted too, since ``json.dumps`` rejects them. Any other object is
    returned unchanged.

    Args:
        obj: Value to convert.

    Returns:
        A structure of the same shape with NumPy values replaced.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, dict):
        return {
            (safe_json_serialize(key) if isinstance(key, np.generic) else key):
                safe_json_serialize(value)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [safe_json_serialize(item) for item in obj]
    if isinstance(obj, tuple):
        # Tuples stay tuples so inputs that already worked round-trip
        # unchanged; json.dumps writes them as arrays.
        return tuple(safe_json_serialize(item) for item in obj)
    return obj