Spaces:
Sleeping
Sleeping
import os | |
import sys | |
import logging | |
# Configure logging | |
logging.basicConfig(level=logging.INFO, | |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
logger = logging.getLogger(__name__) | |
def test_imports(): | |
"""测试所有必需的模块都可以导入""" | |
try: | |
import torch | |
logger.info(f"PyTorch版本: {torch.__version__}") | |
import transformers | |
logger.info(f"Transformers版本: {transformers.__version__}") | |
import numpy as np | |
logger.info(f"NumPy版本: {np.__version__}") | |
import PIL | |
logger.info(f"PIL版本: {PIL.__version__}") | |
import scipy | |
logger.info(f"SciPy版本: {scipy.__version__}") | |
logger.info("所有导入成功") | |
return True | |
except ImportError as e: | |
logger.error(f"导入错误: {str(e)}") | |
return False | |
def test_model_loading(): | |
"""测试模型可以加载""" | |
try: | |
from model import RadarDetectionModel | |
# 检查是否设置了HF_TOKEN环境变量 | |
hf_token = os.environ.get("HF_TOKEN") | |
if not hf_token: | |
logger.warning("未设置HF_TOKEN环境变量,使用公共模型进行测试") | |
# 尝试初始化模型,使用较小的公共模型 | |
logger.info("尝试初始化模型(使用较小的公共模型)") | |
model = RadarDetectionModel(model_name="google/siglip-base-patch16-224") | |
logger.info("模型初始化成功") | |
return True | |
except Exception as e: | |
logger.error(f"模型加载错误: {str(e)}") | |
return False | |
def test_feature_extraction(): | |
"""测试特征提取功能""" | |
try: | |
import numpy as np | |
from PIL import Image | |
from feature_extraction import extract_features | |
# 创建一个虚拟图像和检测结果 | |
logger.info("创建虚拟测试数据") | |
dummy_image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)) | |
dummy_detection = { | |
'boxes': [[50, 50, 100, 100]], | |
'scores': [0.9], | |
'labels': ['测试'] | |
} | |
# 提取特征 | |
logger.info("提取特征") | |
features = extract_features(dummy_image, dummy_detection) | |
logger.info(f"提取的特征: {features}") | |
return True | |
except Exception as e: | |
logger.error(f"特征提取错误: {str(e)}") | |
return False | |
def test_app_initialization(): | |
"""测试应用程序初始化但不加载模型""" | |
try: | |
logger.info("测试应用程序初始化") | |
import app | |
# 检查应用程序是否已初始化但没有加载模型 | |
logger.info("检查应用程序全局变量") | |
assert app.model is None, "模型不应该在导入时加载" | |
assert app.MODEL_INIT_ATTEMPTED is False, "模型初始化尝试标志应为False" | |
logger.info("应用程序初始化测试通过") | |
return True | |
except Exception as e: | |
logger.error(f"应用程序初始化错误: {str(e)}") | |
return False | |
def run_tests(): | |
"""运行所有测试""" | |
tests = [ | |
("导入测试", test_imports), | |
("应用程序初始化测试", test_app_initialization), | |
("模型加载测试", test_model_loading), | |
("特征提取测试", test_feature_extraction) | |
] | |
results = [] | |
for name, test_func in tests: | |
logger.info(f"运行{name}...") | |
try: | |
result = test_func() | |
results.append((name, result)) | |
logger.info(f"{name}: {'通过' if result else '失败'}") | |
except Exception as e: | |
logger.error(f"{name}失败,错误: {str(e)}") | |
results.append((name, False)) | |
# 打印摘要 | |
logger.info("\n--- 测试摘要 ---") | |
passed = sum(1 for _, result in results if result) | |
total = len(results) | |
logger.info(f"通过: {passed}/{total} 测试") | |
for name, result in results: | |
status = "通过" if result else "失败" | |
logger.info(f"{name}: {status}") | |
return passed == total | |
if __name__ == "__main__": | |
success = run_tests() | |
sys.exit(0 if success else 1) |