import os import sys import logging # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) def test_imports(): """测试所有必需的模块都可以导入""" try: import torch logger.info(f"PyTorch版本: {torch.__version__}") import transformers logger.info(f"Transformers版本: {transformers.__version__}") import numpy as np logger.info(f"NumPy版本: {np.__version__}") import PIL logger.info(f"PIL版本: {PIL.__version__}") import scipy logger.info(f"SciPy版本: {scipy.__version__}") logger.info("所有导入成功") return True except ImportError as e: logger.error(f"导入错误: {str(e)}") return False def test_model_loading(): """测试模型可以加载""" try: from model import RadarDetectionModel # 检查是否设置了HF_TOKEN环境变量 hf_token = os.environ.get("HF_TOKEN") if not hf_token: logger.warning("未设置HF_TOKEN环境变量,使用公共模型进行测试") # 尝试初始化模型,使用较小的公共模型 logger.info("尝试初始化模型(使用较小的公共模型)") model = RadarDetectionModel(model_name="google/siglip-base-patch16-224") logger.info("模型初始化成功") return True except Exception as e: logger.error(f"模型加载错误: {str(e)}") return False def test_feature_extraction(): """测试特征提取功能""" try: import numpy as np from PIL import Image from feature_extraction import extract_features # 创建一个虚拟图像和检测结果 logger.info("创建虚拟测试数据") dummy_image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)) dummy_detection = { 'boxes': [[50, 50, 100, 100]], 'scores': [0.9], 'labels': ['测试'] } # 提取特征 logger.info("提取特征") features = extract_features(dummy_image, dummy_detection) logger.info(f"提取的特征: {features}") return True except Exception as e: logger.error(f"特征提取错误: {str(e)}") return False def test_app_initialization(): """测试应用程序初始化但不加载模型""" try: logger.info("测试应用程序初始化") import app # 检查应用程序是否已初始化但没有加载模型 logger.info("检查应用程序全局变量") assert app.model is None, "模型不应该在导入时加载" assert app.MODEL_INIT_ATTEMPTED is False, "模型初始化尝试标志应为False" logger.info("应用程序初始化测试通过") return True except Exception as e: logger.error(f"应用程序初始化错误: {str(e)}") return False def run_tests(): """运行所有测试""" tests = [ ("导入测试", test_imports), ("应用程序初始化测试", test_app_initialization), ("模型加载测试", test_model_loading), ("特征提取测试", test_feature_extraction) ] results = [] for name, test_func in tests: logger.info(f"运行{name}...") try: result = test_func() results.append((name, result)) logger.info(f"{name}: {'通过' if result else '失败'}") except Exception as e: logger.error(f"{name}失败,错误: {str(e)}") results.append((name, False)) # 打印摘要 logger.info("\n--- 测试摘要 ---") passed = sum(1 for _, result in results if result) total = len(results) logger.info(f"通过: {passed}/{total} 测试") for name, result in results: status = "通过" if result else "失败" logger.info(f"{name}: {status}") return passed == total if __name__ == "__main__": success = run_tests() sys.exit(0 if success else 1)