radar-analysis / test_app.py
chenxingqiang
Optimize model loading and improve user experience
3228ab0
import os
import sys
import logging
# Configure logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_imports():
"""测试所有必需的模块都可以导入"""
try:
import torch
logger.info(f"PyTorch版本: {torch.__version__}")
import transformers
logger.info(f"Transformers版本: {transformers.__version__}")
import numpy as np
logger.info(f"NumPy版本: {np.__version__}")
import PIL
logger.info(f"PIL版本: {PIL.__version__}")
import scipy
logger.info(f"SciPy版本: {scipy.__version__}")
logger.info("所有导入成功")
return True
except ImportError as e:
logger.error(f"导入错误: {str(e)}")
return False
def test_model_loading():
"""测试模型可以加载"""
try:
from model import RadarDetectionModel
# 检查是否设置了HF_TOKEN环境变量
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
logger.warning("未设置HF_TOKEN环境变量,使用公共模型进行测试")
# 尝试初始化模型,使用较小的公共模型
logger.info("尝试初始化模型(使用较小的公共模型)")
model = RadarDetectionModel(model_name="google/siglip-base-patch16-224")
logger.info("模型初始化成功")
return True
except Exception as e:
logger.error(f"模型加载错误: {str(e)}")
return False
def test_feature_extraction():
"""测试特征提取功能"""
try:
import numpy as np
from PIL import Image
from feature_extraction import extract_features
# 创建一个虚拟图像和检测结果
logger.info("创建虚拟测试数据")
dummy_image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
dummy_detection = {
'boxes': [[50, 50, 100, 100]],
'scores': [0.9],
'labels': ['测试']
}
# 提取特征
logger.info("提取特征")
features = extract_features(dummy_image, dummy_detection)
logger.info(f"提取的特征: {features}")
return True
except Exception as e:
logger.error(f"特征提取错误: {str(e)}")
return False
def test_app_initialization():
"""测试应用程序初始化但不加载模型"""
try:
logger.info("测试应用程序初始化")
import app
# 检查应用程序是否已初始化但没有加载模型
logger.info("检查应用程序全局变量")
assert app.model is None, "模型不应该在导入时加载"
assert app.MODEL_INIT_ATTEMPTED is False, "模型初始化尝试标志应为False"
logger.info("应用程序初始化测试通过")
return True
except Exception as e:
logger.error(f"应用程序初始化错误: {str(e)}")
return False
def run_tests():
"""运行所有测试"""
tests = [
("导入测试", test_imports),
("应用程序初始化测试", test_app_initialization),
("模型加载测试", test_model_loading),
("特征提取测试", test_feature_extraction)
]
results = []
for name, test_func in tests:
logger.info(f"运行{name}...")
try:
result = test_func()
results.append((name, result))
logger.info(f"{name}: {'通过' if result else '失败'}")
except Exception as e:
logger.error(f"{name}失败,错误: {str(e)}")
results.append((name, False))
# 打印摘要
logger.info("\n--- 测试摘要 ---")
passed = sum(1 for _, result in results if result)
total = len(results)
logger.info(f"通过: {passed}/{total} 测试")
for name, result in results:
status = "通过" if result else "失败"
logger.info(f"{name}: {status}")
return passed == total
if __name__ == "__main__":
success = run_tests()
sys.exit(0 if success else 1)