| |
| """ |
| NeuroScan AI 后端 API 测试用例 |
| 保存到 /mnt/ydchen/NeuroScan/test_case/ |
| """ |
|
|
import json
import os
import shutil
import sys
import tempfile
import time
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple

import nibabel as nib
import numpy as np
import requests
|
|
| |
# Put the directory one level above this file's directory at the front of the
# import path (presumably the project root -- confirm against the repo layout).
sys.path.insert(0, str(Path(__file__).parent.parent))

# Base URL of the backend under test. Defaults to the local development
# server; set the NEUROSCAN_BASE_URL environment variable to point the suite
# at another deployment. A trailing slash is dropped because every endpoint
# appended to this value already starts with "/".
BASE_URL = os.environ.get("NEUROSCAN_BASE_URL", "http://localhost:8080").rstrip("/")
API_PREFIX = "/api/v1"

# Directory that receives one JSON file per logged result plus the run summary.
TEST_RESULTS_DIR = Path(__file__).parent / "results"
TEST_RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
def log_result(test_name: str, success: bool, message: str, data: Optional[dict] = None):
    """Record one test outcome as a JSON file and echo it to stdout.

    Args:
        test_name: Identifier of the test; also used as the file-name prefix.
        success: Whether the test passed.
        message: Human-readable description of the outcome.
        data: Optional extra payload stored under the ``"data"`` key.

    Returns:
        The result dictionary that was written to disk.
    """
    # Read the clock once so the timestamp stored in the record and the one
    # embedded in the file name always describe the same instant.
    now = datetime.now()
    result = {
        "test_name": test_name,
        "success": success,
        "message": message,
        "timestamp": now.isoformat(),
        "data": data,
    }

    # Re-create the output directory if it has been removed since start-up so
    # that logging a result never fails for that reason alone.
    TEST_RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    result_file = TEST_RESULTS_DIR / f"{test_name}_{now.strftime('%Y%m%d_%H%M%S')}.json"
    with open(result_file, 'w', encoding='utf-8') as f:
        json.dump(result, f, ensure_ascii=False, indent=2)

    status = "✅ PASS" if success else "❌ FAIL"
    print(f"{status} - {test_name}: {message}")
    return result
|
|
|
|
class TestCase:
    """Base class for the API tests: owns the HTTP session and builds request URLs."""

    # Seconds to wait for the server before giving up. ``requests`` applies no
    # timeout unless one is passed, so an unresponsive backend would otherwise
    # block the whole run indefinitely. Callers can still pass ``timeout=...``.
    DEFAULT_TIMEOUT = 60

    def __init__(self):
        self.session = requests.Session()
        # Do not pick up proxy or other settings from the environment.
        self.session.trust_env = False

    @staticmethod
    def _url(endpoint: str, use_prefix: bool) -> str:
        """Return the absolute URL for *endpoint*, inserting ``API_PREFIX`` if requested."""
        prefix = API_PREFIX if use_prefix else ""
        return f"{BASE_URL}{prefix}{endpoint}"

    def get(self, endpoint: str, use_prefix: bool = True, **kwargs):
        """Send a GET request to *endpoint*.

        Args:
            endpoint: Path starting with ``/``.
            use_prefix: Prepend ``API_PREFIX`` to the path when true.
            **kwargs: Forwarded to ``Session.get``; ``timeout`` defaults to
                ``DEFAULT_TIMEOUT`` when not given.

        Returns:
            The ``requests`` response object.
        """
        kwargs.setdefault("timeout", self.DEFAULT_TIMEOUT)
        return self.session.get(self._url(endpoint, use_prefix), **kwargs)

    def post(self, endpoint: str, use_prefix: bool = True, **kwargs):
        """Send a POST request to *endpoint*.

        Args:
            endpoint: Path starting with ``/``.
            use_prefix: Prepend ``API_PREFIX`` to the path when true.
            **kwargs: Forwarded to ``Session.post``; ``timeout`` defaults to
                ``DEFAULT_TIMEOUT`` when not given.

        Returns:
            The ``requests`` response object.
        """
        kwargs.setdefault("timeout", self.DEFAULT_TIMEOUT)
        return self.session.post(self._url(endpoint, use_prefix), **kwargs)
|
|
|
|
class TestHealthCheck(TestCase):
    """Check that ``/health`` answers 200 with a ``status`` of ``"healthy"``."""

    def run(self):
        """Call the health endpoint and log the outcome.

        Returns:
            The result dictionary produced by ``log_result``.
        """
        print("\n" + "="*60)
        print("测试 1: 健康检查 API")
        print("="*60)

        try:
            response = self.get("/health", use_prefix=False)
            if response.status_code != 200:
                return log_result("health_check", False, f"状态码: {response.status_code}")

            data = response.json()
            if data.get("status") == "healthy":
                return log_result("health_check", True, "健康检查通过", data)

            # 200 but not healthy: keep the original message and attach the
            # response body, otherwise the record would only say "200" and
            # give no hint about why the check failed.
            return log_result("health_check", False, f"状态码: {response.status_code}", data)
        except Exception as e:
            # Connection errors, invalid JSON and unexpected payload shapes
            # all end up here and are reported as a failed request.
            return log_result("health_check", False, f"请求失败: {str(e)}")
|
|
|
|
class TestRootEndpoint(TestCase):
    """Exercise the root path (``/``) of the service."""

    def run(self):
        """Request ``/`` without the API prefix and log whether it answered 200."""
        banner = "=" * 60
        print("\n" + banner)
        print("测试 2: 根路径 API")
        print(banner)

        try:
            resp = self.get("/", use_prefix=False)
            # Guard clause: anything other than 200 is a failure.
            if resp.status_code != 200:
                return log_result("root_endpoint", False, f"状态码: {resp.status_code}")
            payload = resp.json()
            return log_result("root_endpoint", True, "根路径响应正常", payload)
        except Exception as exc:
            return log_result("root_endpoint", False, f"请求失败: {str(exc)}")
|
|
|
|
class TestAPIDocumentation(TestCase):
    """Check that the OpenAPI document is served."""

    def run(self):
        """Fetch ``/openapi.json`` and log the list of documented paths."""
        banner = "=" * 60
        print("\n" + banner)
        print("测试 3: API 文档")
        print(banner)

        try:
            resp = self.get("/openapi.json", use_prefix=False)
            if resp.status_code != 200:
                return log_result("api_docs", False, f"状态码: {resp.status_code}")

            spec = resp.json()
            # Collect the keys of the "paths" mapping (empty when absent).
            documented = [path for path in spec.get("paths", {}).keys()]
            summary = f"API 文档可用,共 {len(documented)} 个端点"
            return log_result("api_docs", True, summary, {"paths": documented})
        except Exception as exc:
            return log_result("api_docs", False, f"请求失败: {str(exc)}")
|
|
|
|
class TestListScans(TestCase):
    """Exercise the scan-listing endpoint."""

    def run(self):
        """GET ``/scans`` (with the API prefix) and log how many scans were returned."""
        banner = "=" * 60
        print("\n" + banner)
        print("测试 4: 获取扫描列表")
        print(banner)

        try:
            resp = self.get("/scans")
            if resp.status_code != 200:
                return log_result("list_scans", False, f"状态码: {resp.status_code}")

            body = resp.json()
            # The list is read from the "scans" key; a missing key counts as empty.
            scan_count = len(body.get("scans", []))
            details = {"count": scan_count, "data": body}
            return log_result("list_scans", True, f"获取扫描列表成功,共 {scan_count} 条", details)
        except Exception as exc:
            return log_result("list_scans", False, f"请求失败: {str(exc)}")
|
|
|
|
class TestIngestDicom(TestCase):
    """Upload a placeholder ZIP archive to the ``/ingest`` endpoint."""

    def create_dummy_dicom_zip(self) -> Tuple[Path, Path]:
        """Build a small ZIP archive that stands in for a DICOM series.

        The archive holds a single text marker file, not real DICOM data.

        Returns:
            ``(zip_path, temp_dir)``: the archive and the temporary directory
            that contains it. The caller is responsible for removing
            ``temp_dir`` once the archive is no longer needed.
        """
        temp_dir = Path(tempfile.mkdtemp())
        try:
            dicom_dir = temp_dir / "dicom_series"
            dicom_dir.mkdir()

            (dicom_dir / "test_marker.txt").write_text("This is test DICOM data")

            zip_path = temp_dir / "test_dicom.zip"
            with zipfile.ZipFile(zip_path, 'w') as zf:
                for file in dicom_dir.iterdir():
                    # Store entries flat, by file name only.
                    zf.write(file, file.name)
        except Exception:
            # The caller never learns about temp_dir when we fail here, so
            # remove it ourselves instead of leaking it.
            shutil.rmtree(temp_dir, ignore_errors=True)
            raise

        return zip_path, temp_dir

    def run(self):
        """Post the placeholder archive to ``/ingest`` and log the outcome.

        Returns:
            The result dictionary produced by ``log_result``.
        """
        print("\n" + "="*60)
        print("测试 5: DICOM 数据摄入")
        print("="*60)

        temp_dir = None
        try:
            zip_path, temp_dir = self.create_dummy_dicom_zip()

            with open(zip_path, 'rb') as f:
                files = {'file': ('test_dicom.zip', f, 'application/zip')}
                data = {
                    'patient_id': 'TEST_PATIENT_001',
                    'study_date': '2026-01-24'
                }
                response = self.post("/ingest", files=files, data=data)

            if response.status_code == 200:
                result = response.json()
                return log_result("ingest_dicom", True, "DICOM 摄入成功", result)
            elif response.status_code == 500:
                # The upload is not valid DICOM, so a server-side rejection is
                # treated as proof that the endpoint is reachable.
                # NOTE(review): a 4xx validation error would be reported as a
                # failure here while a 500 passes -- confirm this is intended.
                return log_result("ingest_dicom", True,
                                  "API 正常工作(测试数据非有效 DICOM 是预期的)",
                                  {"status_code": 500, "note": "需要真实 DICOM 数据测试"})
            else:
                return log_result("ingest_dicom", False,
                                  f"状态码: {response.status_code}, 响应: {response.text[:200]}")
        except Exception as e:
            return log_result("ingest_dicom", False, f"请求失败: {str(e)}")
        finally:
            # Always remove the temporary archive, whatever the outcome.
            if temp_dir and temp_dir.exists():
                shutil.rmtree(temp_dir, ignore_errors=True)
|
|
|
|
class TestSingleAnalysis(TestCase):
    """Exercise the single-scan analysis endpoint."""

    def run(self):
        """POST an analysis request for a placeholder scan id and log the outcome.

        Both 200 and 404 are logged as success; any other status is a failure.
        """
        banner = "=" * 60
        print("\n" + banner)
        print("测试 6: 单次分析 API")
        print(banner)

        try:
            payload = {
                "scan_id": "test_scan_001",
                "analysis_types": ["segmentation"],
                "target_organs": ["liver", "spleen"]
            }
            resp = self.post("/analyze/single", json=payload)
            code = resp.status_code

            if code == 200:
                return log_result("single_analysis", True, "单次分析请求已接受", resp.json())
            if code == 404:
                # The placeholder scan id is not expected to exist on the server.
                return log_result("single_analysis", True,
                                  "API 正常工作(扫描不存在是预期的)",
                                  {"status_code": 404, "message": "需要先上传扫描数据"})
            return log_result("single_analysis", False,
                              f"状态码: {code}, 响应: {resp.text[:200]}")
        except Exception as exc:
            return log_result("single_analysis", False, f"请求失败: {str(exc)}")
|
|
|
|
class TestLongitudinalAnalysis(TestCase):
    """Exercise the longitudinal (baseline vs. follow-up) analysis endpoint."""

    def run(self):
        """POST a comparison request for two placeholder scan ids and log the outcome.

        Both 200 and 404 are logged as success; any other status is a failure.
        """
        banner = "=" * 60
        print("\n" + banner)
        print("测试 7: 纵向对比分析 API")
        print(banner)

        try:
            payload = {
                "baseline_scan_id": "test_scan_baseline",
                "followup_scan_id": "test_scan_followup",
                "analysis_types": ["registration", "difference"]
            }
            resp = self.post("/analyze/longitudinal", json=payload)
            code = resp.status_code

            if code == 200:
                return log_result("longitudinal_analysis", True, "纵向分析请求已接受", resp.json())
            if code == 404:
                # The placeholder scan ids are not expected to exist on the server.
                return log_result("longitudinal_analysis", True,
                                  "API 正常工作(扫描不存在是预期的)",
                                  {"status_code": 404, "message": "需要先上传扫描数据"})
            return log_result("longitudinal_analysis", False,
                              f"状态码: {code}, 响应: {resp.text[:200]}")
        except Exception as exc:
            return log_result("longitudinal_analysis", False, f"请求失败: {str(exc)}")
|
|
|
|
class TestReportRetrieval(TestCase):
    """Exercise the report retrieval endpoint."""

    def run(self):
        """GET the report of a placeholder task id and log the outcome.

        Both 200 and 404 are logged as success; any other status is a failure.
        """
        banner = "=" * 60
        print("\n" + banner)
        print("测试 8: 报告获取 API")
        print(banner)

        try:
            resp = self.get("/reports/test_task_123")
            code = resp.status_code

            if code == 200:
                return log_result("report_retrieval", True, "报告获取成功", resp.json())
            if code == 404:
                # The placeholder task id is not expected to exist on the server.
                return log_result("report_retrieval", True,
                                  "API 正常工作(任务不存在是预期的)",
                                  {"status_code": 404, "message": "任务不存在"})
            return log_result("report_retrieval", False, f"状态码: {code}")
        except Exception as exc:
            return log_result("report_retrieval", False, f"请求失败: {str(exc)}")
|
|
|
|
class TestCORSHeaders(TestCase):
    """Send a CORS preflight request and record the CORS response headers."""

    # Seconds to wait for the preflight response. The OPTIONS call below goes
    # straight to the session, and ``requests`` applies no timeout unless one
    # is passed, so without this an unresponsive server would hang the run.
    REQUEST_TIMEOUT = 30

    def run(self):
        """Issue an OPTIONS request against ``/health`` and log the CORS headers.

        Returns:
            The result dictionary produced by ``log_result``.
        """
        print("\n" + "="*60)
        print("测试 9: CORS 配置")
        print("="*60)

        try:
            response = self.session.options(
                f"{BASE_URL}/health",
                headers={
                    "Origin": "http://localhost:8501",
                    "Access-Control-Request-Method": "GET"
                },
                timeout=self.REQUEST_TIMEOUT,
            )

            cors_headers = {
                "access-control-allow-origin": response.headers.get("access-control-allow-origin"),
                "access-control-allow-methods": response.headers.get("access-control-allow-methods"),
            }

            if cors_headers["access-control-allow-origin"]:
                return log_result("cors_config", True, "CORS 配置正确", cors_headers)
            else:
                # NOTE(review): a missing allow-origin header is also logged as
                # a pass, so this test only fails when the request itself
                # fails. Confirm that this leniency is intended.
                return log_result("cors_config", True, "CORS 可能使用通配符配置", cors_headers)
        except Exception as e:
            return log_result("cors_config", False, f"请求失败: {str(e)}")
|
|
|
|
class TestResponseTime(TestCase):
    """Measure the response time of a few endpoints."""

    def run(self):
        """Time one GET per endpoint and pass when the average is below 500 ms.

        Returns:
            The result dictionary produced by ``log_result``; its ``data``
            maps each full path to its status code and response time.
        """
        print("\n" + "="*60)
        print("测试 10: API 响应时间")
        print("="*60)

        try:
            # (endpoint, use_prefix) pairs to be timed.
            endpoints = [
                ("/health", False),
                ("/", False),
                ("/scans", True),
            ]
            results = {}

            for endpoint, use_prefix in endpoints:
                # perf_counter() is monotonic and high-resolution; time.time()
                # follows the wall clock and can jump when the system clock is
                # adjusted, which would corrupt the measured interval.
                start = time.perf_counter()
                response = self.get(endpoint, use_prefix=use_prefix)
                elapsed = (time.perf_counter() - start) * 1000
                full_path = f"{API_PREFIX if use_prefix else ''}{endpoint}"
                results[full_path] = {
                    "status_code": response.status_code,
                    "response_time_ms": round(elapsed, 2)
                }

            avg_time = sum(r["response_time_ms"] for r in results.values()) / len(results)

            if avg_time < 500:
                return log_result("response_time", True,
                                  f"平均响应时间: {avg_time:.2f}ms", results)
            else:
                return log_result("response_time", False,
                                  f"响应时间过长: {avg_time:.2f}ms", results)
        except Exception as e:
            return log_result("response_time", False, f"请求失败: {str(e)}")
|
|
|
|
def run_all_tests():
    """Run every API test, print a summary and write it to a JSON file.

    Returns:
        ``True`` when all tests passed, ``False`` otherwise.
    """
    print("\n" + "="*60)
    print("NeuroScan AI 后端 API 测试")
    print(f"时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"目标: {BASE_URL}")
    print("="*60)

    tests = [
        TestHealthCheck(),
        TestRootEndpoint(),
        TestAPIDocumentation(),
        TestListScans(),
        TestIngestDicom(),
        TestSingleAnalysis(),
        TestLongitudinalAnalysis(),
        TestReportRetrieval(),
        TestCORSHeaders(),
        TestResponseTime(),
    ]

    results = []
    passed = 0
    failed = 0

    for test in tests:
        try:
            result = test.run()
            succeeded = bool(result["success"])
        except Exception as e:
            print(f"❌ 测试异常: {str(e)}")
            # Keep a record for the crashed test so that the summary's
            # "results" list has an entry for every test that was run.
            result = {
                "test_name": type(test).__name__,
                "success": False,
                "message": f"测试异常: {str(e)}",
                "timestamp": datetime.now().isoformat(),
                "data": None,
            }
            succeeded = False

        results.append(result)
        if succeeded:
            passed += 1
        else:
            failed += 1

    # Computed once and reused for both the console output and the summary file.
    pass_rate = f"{passed/len(tests)*100:.1f}%"

    print("\n" + "="*60)
    print("测试总结")
    print("="*60)
    print(f"总计: {len(tests)} 个测试")
    print(f"通过: {passed} ✅")
    print(f"失败: {failed} ❌")
    print(f"通过率: {pass_rate}")

    summary = {
        "timestamp": datetime.now().isoformat(),
        "base_url": BASE_URL,
        "total_tests": len(tests),
        "passed": passed,
        "failed": failed,
        "pass_rate": pass_rate,
        "results": results
    }

    summary_file = TEST_RESULTS_DIR / f"test_summary_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    print(f"\n测试结果已保存到: {TEST_RESULTS_DIR}")

    return passed == len(tests)
|
|
|
|
if __name__ == "__main__":
    # Exit status 0 when every test passed, 1 otherwise.
    sys.exit(0 if run_all_tests() else 1)
|
|
|
|