Spaces:
Runtime error
Runtime error
Debug
Browse files
app.py
CHANGED
|
@@ -20,22 +20,31 @@ subprocess.run(shlex.split("pip install wheel/pointops-1.0-cp310-cp310-linux_x86
|
|
| 20 |
from src.utils.visualization_utils import render_video_from_file
|
| 21 |
from src.model import LSM_MASt3R
|
| 22 |
|
| 23 |
-
#
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
# model_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
|
| 29 |
-
|
| 30 |
-
# Load model
|
| 31 |
-
# model = LSM_MASt3R.from_pretrained(model_path)
|
| 32 |
-
# model = model.eval()
|
| 33 |
|
|
|
|
|
|
|
| 34 |
|
| 35 |
try:
|
| 36 |
-
#
|
| 37 |
-
|
| 38 |
-
print(f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
# 加载模型
|
| 41 |
model = LSM_MASt3R.from_pretrained(model_path, device='cuda')
|
|
@@ -43,11 +52,15 @@ try:
|
|
| 43 |
print("模型加载成功并设置为评估模式!")
|
| 44 |
|
| 45 |
except FileNotFoundError:
|
| 46 |
-
print(f"错误:
|
| 47 |
except KeyError as e:
|
| 48 |
print(f"错误: 检查点文件格式不正确,缺少键 {e}。请确认 checkpoint-40.pth 包含 'args' 和 'model'。")
|
| 49 |
except Exception as e:
|
| 50 |
print(f"发生未知错误: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
|
| 53 |
|
|
|
|
| 20 |
from src.utils.visualization_utils import render_video_from_file
|
| 21 |
from src.model import LSM_MASt3R
|
| 22 |
|
| 23 |
+
# 定义相对路径和 Hugging Face 仓库信息
|
| 24 |
+
relative_model_dir = "checkpoints" # 文件夹名称
|
| 25 |
+
relative_model_path = os.path.join(relative_model_dir, "checkpoint-40.pth") # 相对路径
|
| 26 |
+
model_repo = "kairunwen/LSM" # Hugging Face 仓库
|
| 27 |
+
model_filename = "checkpoint-40.pth" # 仓库中的文件名
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
# 转换为绝对路径
|
| 30 |
+
model_path = os.path.abspath(relative_model_path)
|
| 31 |
|
| 32 |
try:
|
| 33 |
+
# 创建 checkpoints 文件夹(如果不存在)
|
| 34 |
+
os.makedirs(relative_model_dir, exist_ok=True)
|
| 35 |
+
print(f"确保 {relative_model_dir} 文件夹存在")
|
| 36 |
+
|
| 37 |
+
# 验证文件是否存在
|
| 38 |
+
if os.path.exists(model_path):
|
| 39 |
+
print(f"找到本地模型文件: {model_path}")
|
| 40 |
+
else:
|
| 41 |
+
print(f"本地模型文件 {model_path} 不存在,正在从 Hugging Face 下载...")
|
| 42 |
+
model_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
|
| 43 |
+
# 可选:将下载的文件移动到 checkpoints 文件夹
|
| 44 |
+
import shutil
|
| 45 |
+
shutil.move(model_path, os.path.abspath(relative_model_path))
|
| 46 |
+
model_path = os.path.abspath(relative_model_path)
|
| 47 |
+
print(f"模型文件已下载并移动到: {model_path}")
|
| 48 |
|
| 49 |
# 加载模型
|
| 50 |
model = LSM_MASt3R.from_pretrained(model_path, device='cuda')
|
|
|
|
| 52 |
print("模型加载成功并设置为评估模式!")
|
| 53 |
|
| 54 |
except FileNotFoundError:
|
| 55 |
+
print(f"错误: 无法找到或下载文件 {model_filename},请检查路径或仓库 {model_repo}。")
|
| 56 |
except KeyError as e:
|
| 57 |
print(f"错误: 检查点文件格式不正确,缺少键 {e}。请确认 checkpoint-40.pth 包含 'args' 和 'model'。")
|
| 58 |
except Exception as e:
|
| 59 |
print(f"发生未知错误: {e}")
|
| 60 |
+
# 调试:检查检查点内容
|
| 61 |
+
ckpt = torch.load(model_path, map_location='cpu')
|
| 62 |
+
print("检查点键:", ckpt.keys())
|
| 63 |
+
print("config.model:", ckpt['args'].model)
|
| 64 |
|
| 65 |
|
| 66 |
|