""" |
|
|
Comprehensive API test script for YLFF endpoints. |
|
|
|
|
|
Tests all API endpoints including: |
|
|
- Health and system endpoints |
|
|
- Validation endpoints (sequence, ARKit) |
|
|
- Training endpoints (fine-tuning, pre-training) with optimization parameters |
|
|
- Dataset building with optimization parameters |
|
|
- Job management |
|
|
- Profiling endpoints |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import logging |
|
|
import sys |
|
|
import time |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, List, Optional |
|
|
import requests |
|
|
|
|
|
|
|
|
logging.basicConfig( |
|
|
level=logging.INFO, |
|
|
format="%(asctime)s - %(levelname)s - %(message)s", |
|
|
datefmt="%H:%M:%S", |
|
|
stream=sys.stdout, |
|
|
force=True, |
|
|
) |
|
|
logger = logging.getLogger(__name__) |
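
# Example invocation (the script filename and directory paths are illustrative;
# the flags correspond to the argparse options defined in main() below):
#
#   python test_api.py --base-url http://localhost:8000 \
#       --sequence-dir data/sequences/seq_001 \
#       --sequences-dir data/sequences \
#       --training-data-dir data/test_training \
#       --skip-polling --output data/api_test_results.json
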
class APITester:
    """API testing utility class."""

    def __init__(self, base_url: str, timeout: int = 300):
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.results: List[tuple[str, Dict[str, Any]]] = []
        self.job_ids: List[str] = []

    def test_endpoint(
        self,
        method: str,
        endpoint: str,
        description: str = "",
        **kwargs,
    ) -> Dict[str, Any]:
        """Test a single endpoint."""
        url = f"{self.base_url}{endpoint}"
        desc = f" ({description})" if description else ""
        logger.info(f"→ {method} {endpoint}{desc}")

        try:
            start_time = time.time()
            response = requests.request(method, url, timeout=self.timeout, **kwargs)
            duration = time.time() - start_time

            logger.info(f"← {response.status_code} ({duration:.3f}s)")

            try:
                data = response.json() if response.content else None
            except json.JSONDecodeError:
                data = response.text

            result = {
                "status_code": response.status_code,
                "data": data,
                "duration": duration,
                "success": 200 <= response.status_code < 300,
            }

            # Track job IDs returned by asynchronous endpoints so they can be polled later.
            if result.get("success") and data and isinstance(data, dict):
                job_id = data.get("job_id")
                if job_id:
                    self.job_ids.append(job_id)
                    logger.info(f" Job ID: {job_id}")

            return result
        except requests.exceptions.RequestException as e:
            logger.error(f"✗ Request failed: {e}")
            return {"status_code": None, "error": str(e), "success": False}

    def test_health_endpoints(self):
        """Test health and system endpoints."""
        logger.info("\n" + "=" * 80)
        logger.info("HEALTH & SYSTEM ENDPOINTS")
        logger.info("=" * 80)

        result = self.test_endpoint("GET", "/health", "Health check")
        self.results.append(("GET /health", result))

        result = self.test_endpoint("GET", "/", "Root endpoint")
        self.results.append(("GET /", result))

        result = self.test_endpoint("GET", "/api/v1/models", "List models")
        self.results.append(("GET /api/v1/models", result))

        result = self.test_endpoint("GET", "/api/v1/jobs", "List jobs")
        self.results.append(("GET /api/v1/jobs", result))

    def test_profiling_endpoints(self):
        """Test profiling endpoints."""
        logger.info("\n" + "=" * 80)
        logger.info("PROFILING ENDPOINTS")
        logger.info("=" * 80)

        endpoints = [
            ("/api/v1/profiling/metrics", "Profiling metrics"),
            ("/api/v1/profiling/hot-paths", "Hot paths"),
            ("/api/v1/profiling/latency", "Latency breakdown"),
            ("/api/v1/profiling/system", "System metrics"),
        ]

        for endpoint, desc in endpoints:
            result = self.test_endpoint("GET", endpoint, desc)
            self.results.append((f"GET {endpoint}", result))

    def test_validation_endpoints(
        self, sequence_dir: Optional[str] = None, arkit_dir: Optional[str] = None
    ):
        """Test validation endpoints."""
        logger.info("\n" + "=" * 80)
        logger.info("VALIDATION ENDPOINTS")
        logger.info("=" * 80)

        if sequence_dir:
            payload = {
                "sequence_dir": sequence_dir,
                "use_case": "ba_validation",
                "accept_threshold": 2.0,
                "reject_threshold": 30.0,
            }
            result = self.test_endpoint(
                "POST",
                "/api/v1/validate/sequence",
                f"Validate sequence: {sequence_dir}",
                json=payload,
            )
            self.results.append(("POST /api/v1/validate/sequence", result))
        else:
            logger.info("Skipping /api/v1/validate/sequence (no sequence_dir)")

        if arkit_dir:
            payload = {
                "arkit_dir": arkit_dir,
                "output_dir": "data/test_arkit_output",
                "max_frames": 10,
                "frame_interval": 1,
                "device": "cuda",
                "gui": False,
            }
            result = self.test_endpoint(
                "POST",
                "/api/v1/validate/arkit",
                f"Validate ARKit: {arkit_dir}",
                json=payload,
            )
            self.results.append(("POST /api/v1/validate/arkit", result))
        else:
            logger.info("Skipping /api/v1/validate/arkit (no arkit_dir)")

    def test_dataset_building_endpoints(
        self,
        sequences_dir: Optional[str] = None,
        test_optimizations: bool = True,
    ):
        """Test dataset building endpoint with optimizations."""
        logger.info("\n" + "=" * 80)
        logger.info("DATASET BUILDING ENDPOINTS")
        logger.info("=" * 80)

        if not sequences_dir:
            logger.info("Skipping /api/v1/dataset/build (no sequences_dir)")
            return

        if test_optimizations:
            payload = {
                "sequences_dir": sequences_dir,
                "output_dir": "data/test_training",
                "max_samples": 10,
                "accept_threshold": 2.0,
                "reject_threshold": 30.0,
                "use_batched_inference": True,
                "inference_batch_size": 4,
                "use_inference_cache": True,
                "cache_dir": "cache/test_inference",
                "compile_model": True,
            }
            result = self.test_endpoint(
                "POST",
                "/api/v1/dataset/build",
                "Build dataset with optimizations",
                json=payload,
            )
            self.results.append(("POST /api/v1/dataset/build (optimized)", result))

        payload = {
            "sequences_dir": sequences_dir,
            "output_dir": "data/test_training_baseline",
            "max_samples": 10,
            "accept_threshold": 2.0,
            "reject_threshold": 30.0,
            "use_batched_inference": False,
            "use_inference_cache": False,
            "compile_model": False,
        }
        result = self.test_endpoint(
            "POST",
            "/api/v1/dataset/build",
            "Build dataset (baseline)",
            json=payload,
        )
        self.results.append(("POST /api/v1/dataset/build (baseline)", result))
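
    # For manual debugging, an equivalent dataset-build request can be issued with
    # curl (illustrative only; payload trimmed to a few of the fields used above):
    #
    #   curl -X POST http://localhost:8000/api/v1/dataset/build \
    #       -H "Content-Type: application/json" \
    #       -d '{"sequences_dir": "data/sequences", "output_dir": "data/test_training",
    #            "max_samples": 10, "use_batched_inference": true}'
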
    def test_training_endpoints(
        self,
        training_data_dir: Optional[str] = None,
        test_optimizations: bool = True,
    ):
        """Test training endpoints with optimization parameters."""
        logger.info("\n" + "=" * 80)
        logger.info("TRAINING ENDPOINTS")
        logger.info("=" * 80)

        if not training_data_dir:
            logger.info("Skipping /api/v1/train/start (no training_data_dir)")
            return

        if test_optimizations:
            payload = {
                "training_data_dir": training_data_dir,
                "epochs": 1,
                "lr": 1e-5,
                "batch_size": 1,
                "checkpoint_dir": "checkpoints/test",
                "device": "cuda",
                "use_wandb": False,
                # Training-loop optimizations
                "gradient_accumulation_steps": 4,
                "use_amp": True,
                "warmup_steps": 10,
                "num_workers": 2,
                "use_ema": True,
                "ema_decay": 0.9999,
                "use_onecycle": False,
                "use_gradient_checkpointing": False,
                "compile_model": True,
                # Precision and auto-tuning options
                "use_bf16": False,
                "gradient_clip_norm": 1.0,
                "find_lr": False,
                "find_batch_size": False,
                # FSDP options
                "use_fsdp": False,
                "fsdp_sharding_strategy": "FULL_SHARD",
                "fsdp_mixed_precision": None,
                # Quantization-aware training and parallelism options
                "use_qat": False,
                "qat_backend": "fbgemm",
                "use_sequence_parallel": False,
                "sequence_parallel_gpus": 1,
                "activation_recompute_strategy": None,
                # Checkpointing options
                "async_checkpoint": True,
                "compress_checkpoint": True,
            }
            result = self.test_endpoint(
                "POST",
                "/api/v1/train/start",
                "Fine-tune with optimizations",
                json=payload,
            )
            self.results.append(("POST /api/v1/train/start (optimized)", result))

        # Baseline run without the optimization flags, for comparison.
        payload = {
            "training_data_dir": training_data_dir,
            "epochs": 1,
            "lr": 1e-5,
            "batch_size": 1,
            "checkpoint_dir": "checkpoints/test_baseline",
            "device": "cuda",
            "use_wandb": False,
            "gradient_accumulation_steps": 1,
            "use_amp": False,
            "compile_model": False,
        }
        result = self.test_endpoint(
            "POST",
            "/api/v1/train/start",
            "Fine-tune (baseline)",
            json=payload,
        )
        self.results.append(("POST /api/v1/train/start (baseline)", result))

    def test_pretraining_endpoints(
        self,
        arkit_sequences_dir: Optional[str] = None,
        test_optimizations: bool = True,
    ):
        """Test pre-training endpoints with optimization parameters."""
        logger.info("\n" + "=" * 80)
        logger.info("PRE-TRAINING ENDPOINTS")
        logger.info("=" * 80)

        if not arkit_sequences_dir:
            logger.info("Skipping /api/v1/train/pretrain (no arkit_sequences_dir)")
            return

        if test_optimizations:
            payload = {
                "arkit_sequences_dir": arkit_sequences_dir,
                "epochs": 1,
                "lr": 1e-4,
                "batch_size": 1,
                "checkpoint_dir": "checkpoints/test_pretrain",
                "device": "cuda",
                "max_sequences": 1,
                "max_frames_per_sequence": 10,
                "frame_interval": 1,
                "use_lidar": False,
                "use_ba_depth": False,
                "min_ba_quality": 0.0,
                "use_wandb": False,
                # Training-loop optimizations
                "gradient_accumulation_steps": 4,
                "use_amp": True,
                "warmup_steps": 10,
                "num_workers": 2,
                "use_ema": True,
                "ema_decay": 0.9999,
                "use_onecycle": False,
                "use_gradient_checkpointing": False,
                "compile_model": True,
                "cache_dir": "cache/test_ba",
                # Precision and auto-tuning options
                "use_bf16": False,
                "gradient_clip_norm": 1.0,
                "find_lr": False,
                "find_batch_size": False,
                # FSDP options
                "use_fsdp": False,
                "fsdp_sharding_strategy": "FULL_SHARD",
                "fsdp_mixed_precision": None,
                # Quantization-aware training and parallelism options
                "use_qat": False,
                "qat_backend": "fbgemm",
                "use_sequence_parallel": False,
                "sequence_parallel_gpus": 1,
                "activation_recompute_strategy": None,
                # Checkpointing options
                "async_checkpoint": True,
                "compress_checkpoint": True,
            }
            result = self.test_endpoint(
                "POST",
                "/api/v1/train/pretrain",
                "Pre-train with optimizations",
                json=payload,
            )
            self.results.append(("POST /api/v1/train/pretrain (optimized)", result))

    def poll_jobs(self, max_polls: int = 60, poll_interval: int = 5):
        """Poll job status until completion."""
        logger.info("\n" + "=" * 80)
        logger.info("POLLING JOBS")
        logger.info("=" * 80)

        if not self.job_ids:
            logger.info("No jobs to monitor")
            return

        logger.info(f"Monitoring {len(self.job_ids)} job(s)")

        for job_id in self.job_ids:
            logger.info(f"\nMonitoring job: {job_id}")
            for poll_num in range(max_polls):
                result = self.test_endpoint(
                    "GET",
                    f"/api/v1/jobs/{job_id}",
                    f"Job status (poll {poll_num + 1}/{max_polls})",
                )

                if result.get("success") and result.get("data"):
                    data = result["data"]
                    status = data.get("status", "unknown")
                    message = data.get("message", "")
                    logger.info(f" Status: {status}, Message: {message[:60]}")

                    if status in ["completed", "failed"]:
                        logger.info(f" Job {status}!")
                        if status == "completed":
                            job_result = data.get("result", {})
                            if job_result:
                                logger.info(f" Result keys: {list(job_result.keys())}")
                        break

                    if poll_num < max_polls - 1:
                        time.sleep(poll_interval)
                else:
                    logger.warning(" Failed to get job status")
                    break

            self.results.append((f"GET /api/v1/jobs/{job_id} (final)", result))

    def print_summary(self):
        """Print test summary."""
        logger.info("\n" + "=" * 80)
        logger.info("TEST SUMMARY")
        logger.info("=" * 80)

        success_count = sum(1 for _, r in self.results if r.get("success"))
        total_count = len(self.results)

        logger.info(f"Success: {success_count}/{total_count}")
        logger.info("")

        logger.info("Endpoint Results:")
        for endpoint, result in self.results:
            status = "✓" if result.get("success") else "✗"
            status_code = result.get("status_code", "N/A")
            duration = result.get("duration", 0)
            status_code_str = str(status_code) if status_code is not None else "N/A"
            logger.info(f"{status} {endpoint:60s} {status_code_str:>3} ({duration:.3f}s)")

    def save_results(self, output_file: Path):
        """Save test results to JSON file."""
        output_file.parent.mkdir(parents=True, exist_ok=True)

        output_data = {
            "timestamp": datetime.now().isoformat(),
            "base_url": self.base_url,
            "summary": {
                "total_tests": len(self.results),
                "successful": sum(1 for _, r in self.results if r.get("success")),
                "failed": sum(1 for _, r in self.results if not r.get("success")),
            },
            "results": [
                {
                    "endpoint": endpoint,
                    "status_code": r.get("status_code"),
                    "success": r.get("success"),
                    "duration": r.get("duration"),
                    "data": r.get("data") if r.get("success") else None,
                    "error": r.get("error") if not r.get("success") else None,
                }
                for endpoint, r in self.results
            ],
        }

        with open(output_file, "w") as f:
            json.dump(output_data, f, indent=2, default=str)

        logger.info(f"\nResults saved to: {output_file}")
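
    # The JSON file written by save_results has roughly this shape
    # (values below are illustrative):
    #
    #   {
    #     "timestamp": "2024-01-01T12:00:00",
    #     "base_url": "http://localhost:8000",
    #     "summary": {"total_tests": 10, "successful": 9, "failed": 1},
    #     "results": [
    #       {"endpoint": "GET /health", "status_code": 200, "success": true,
    #        "duration": 0.012, "data": {...}, "error": null}
    #     ]
    #   }
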
def main():
    """Main test function."""
    parser = argparse.ArgumentParser(description="Comprehensive YLFF API endpoint testing")
    parser.add_argument(
        "--base-url",
        default="http://localhost:8000",
        help="Base URL for API",
    )
    parser.add_argument("--sequence-dir", type=str, help="Sequence directory for validation")
    parser.add_argument("--arkit-dir", type=str, help="ARKit directory for validation")
    parser.add_argument(
        "--sequences-dir",
        type=str,
        help="Sequences directory for dataset building",
    )
    parser.add_argument(
        "--training-data-dir",
        type=str,
        help="Training data directory for fine-tuning",
    )
    parser.add_argument(
        "--arkit-sequences-dir",
        type=str,
        help="ARKit sequences directory for pre-training",
    )
    parser.add_argument(
        "--skip-optimizations",
        action="store_true",
        help="Skip optimization parameter tests",
    )
    parser.add_argument(
        "--skip-polling",
        action="store_true",
        help="Skip job polling",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("data/api_test_results.json"),
        help="Output file for results",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=300,
        help="Request timeout in seconds",
    )

    args = parser.parse_args()

    logger.info("=" * 80)
    logger.info("YLFF API COMPREHENSIVE TEST")
    logger.info("=" * 80)
    logger.info(f"Base URL: {args.base_url}")
    logger.info(f"Timeout: {args.timeout}s")
    logger.info("")

    tester = APITester(args.base_url, timeout=args.timeout)

    # Run all endpoint test suites
    tester.test_health_endpoints()
    tester.test_profiling_endpoints()
    tester.test_validation_endpoints(sequence_dir=args.sequence_dir, arkit_dir=args.arkit_dir)
    tester.test_dataset_building_endpoints(
        sequences_dir=args.sequences_dir,
        test_optimizations=not args.skip_optimizations,
    )
    tester.test_training_endpoints(
        training_data_dir=args.training_data_dir,
        test_optimizations=not args.skip_optimizations,
    )
    tester.test_pretraining_endpoints(
        arkit_sequences_dir=args.arkit_sequences_dir,
        test_optimizations=not args.skip_optimizations,
    )

    # Poll any submitted jobs until they finish, unless polling is skipped
    if not args.skip_polling:
        tester.poll_jobs()

    # Summarize and persist results
    tester.print_summary()
    tester.save_results(args.output)

    # Exit code reflects whether every endpoint test succeeded
    success_count = sum(1 for _, r in tester.results if r.get("success"))
    total_count = len(tester.results)
    return 0 if success_count == total_count else 1


if __name__ == "__main__":
    sys.exit(main())