import os
import json
import logging
import sys
import re
import html
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

from fastapi import FastAPI, Form, HTTPException, Request
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware

from datasets import Dataset, concatenate_datasets, load_dataset
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the registry-driven field configuration when available; fall back to
# empty structures so the API can still run without the registry.
try:
    from src.aibom_generator.field_registry_manager import (
        get_field_registry_manager,
        generate_field_classification,
        get_configurable_scoring_weights,
    )
    REGISTRY_MANAGER = get_field_registry_manager()
    FIELD_CLASSIFICATION = generate_field_classification()
    SCORING_WEIGHTS = get_configurable_scoring_weights()
    REGISTRY_AVAILABLE = True
    logger.info(f"✅ Registry-driven API: {len(FIELD_CLASSIFICATION)} fields loaded")
except ImportError as e:
    REGISTRY_AVAILABLE = False
    FIELD_CLASSIFICATION = {}
    SCORING_WEIGHTS = {}
    logger.warning(f"⚠️ Registry not available for API: {e}")

# Paths and file-retention policy
templates_dir = "templates"
OUTPUT_DIR = "/tmp/aibom_output"
MAX_AGE_DAYS = 7        # delete generated files older than this
MAX_FILES = 1000        # keep at most this many generated files
CLEANUP_INTERVAL = 100  # run cleanup once every N requests

# Hugging Face dataset used as a usage log
HF_REPO = "aetheris-ai/aisbom-usage-log"
HF_TOKEN = os.getenv("HF_TOKEN")

app = FastAPI(title="AI SBOM Generator API")

# DoS-protection middleware: prefer the real implementations, fall back to
# pass-through stubs so the application still starts if they cannot be imported.
try:
    from src.aibom_generator.rate_limiting import RateLimitMiddleware, ConcurrencyLimitMiddleware, RequestSizeLimitMiddleware
    logger.info("Successfully imported rate_limiting from src.aibom_generator")
except ImportError:
    try:
        from .rate_limiting import RateLimitMiddleware, ConcurrencyLimitMiddleware, RequestSizeLimitMiddleware
        logger.info("Successfully imported rate_limiting with relative import")
    except ImportError:
        try:
            from rate_limiting import RateLimitMiddleware, ConcurrencyLimitMiddleware, RequestSizeLimitMiddleware
            logger.info("Successfully imported rate_limiting from current directory")
        except ImportError:
            logger.error("Could not import rate_limiting, DoS protection disabled")

            class RateLimitMiddleware(BaseHTTPMiddleware):
                """Pass-through stub used when the real middleware is unavailable."""

                def __init__(self, app, **kwargs):
                    super().__init__(app)

                async def dispatch(self, request, call_next):
                    try:
                        return await call_next(request)
                    except Exception as e:
                        logger.error(f"Error in RateLimitMiddleware: {str(e)}")
                        return JSONResponse(
                            status_code=500,
                            content={"detail": f"Internal server error: {str(e)}"},
                        )

            class ConcurrencyLimitMiddleware(BaseHTTPMiddleware):
                """Pass-through stub used when the real middleware is unavailable."""

                def __init__(self, app, **kwargs):
                    super().__init__(app)

                async def dispatch(self, request, call_next):
                    try:
                        return await call_next(request)
                    except Exception as e:
                        logger.error(f"Error in ConcurrencyLimitMiddleware: {str(e)}")
                        return JSONResponse(
                            status_code=500,
                            content={"detail": f"Internal server error: {str(e)}"},
                        )

            class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
                """Pass-through stub used when the real middleware is unavailable."""

                def __init__(self, app, **kwargs):
                    super().__init__(app)

                async def dispatch(self, request, call_next):
                    try:
                        return await call_next(request)
                    except Exception as e:
                        logger.error(f"Error in RequestSizeLimitMiddleware: {str(e)}")
                        return JSONResponse(
                            status_code=500,
                            content={"detail": f"Internal server error: {str(e)}"},
                        )

# CAPTCHA verification: fall back to a dummy verifier that always succeeds.
try:
    from src.aibom_generator.captcha import verify_recaptcha
    logger.info("Successfully imported captcha from src.aibom_generator")
except ImportError:
    try:
        from .captcha import verify_recaptcha
        logger.info("Successfully imported captcha with relative import")
    except ImportError:
        try:
            from captcha import verify_recaptcha
            logger.info("Successfully imported captcha from current directory")
        except ImportError:
            logger.warning("Could not import captcha module, CAPTCHA verification disabled")

            def verify_recaptcha(response_token=None):
                """Dummy CAPTCHA verification used when the captcha module is unavailable."""
                logger.warning("Using dummy CAPTCHA verification (always succeeds)")
                return True

# Register the DoS-protection middleware.
app.add_middleware(
    RateLimitMiddleware,
    rate_limit_per_minute=10,
    rate_limit_window=60,
    protected_routes=["/generate", "/api/generate", "/api/generate-with-report"]
)

app.add_middleware(
    ConcurrencyLimitMiddleware,
    max_concurrent_requests=5,
    timeout=5.0,
    protected_routes=["/generate", "/api/generate", "/api/generate-with-report"]
)

app.add_middleware(
    RequestSizeLimitMiddleware,
    max_content_length=1024 * 1024  # 1 MiB
)


class StatusResponse(BaseModel):
    status: str
    version: str
    generator_version: str


templates = Jinja2Templates(directory=templates_dir)

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Serve generated SBOM files for download.
app.mount("/output", StaticFiles(directory=OUTPUT_DIR), name="output")

# Counter used to trigger periodic cleanup.
request_counter = 0

try:
    from src.aibom_generator.cleanup_utils import perform_cleanup
    logger.info("Successfully imported cleanup_utils")
except ImportError:
    try:
        from cleanup_utils import perform_cleanup
        logger.info("Successfully imported cleanup_utils from current directory")
    except ImportError:
        logger.error("Could not import cleanup_utils, defining functions inline")

        def cleanup_old_files(directory, max_age_days=7):
            """Remove files older than max_age_days from the specified directory."""
            if not os.path.exists(directory):
                logger.warning(f"Directory does not exist: {directory}")
                return 0

            removed_count = 0
            now = datetime.now()

            try:
                for filename in os.listdir(directory):
                    file_path = os.path.join(directory, filename)
                    if os.path.isfile(file_path):
                        file_age = now - datetime.fromtimestamp(os.path.getmtime(file_path))
                        if file_age.days > max_age_days:
                            try:
                                os.remove(file_path)
                                removed_count += 1
                                logger.info(f"Removed old file: {file_path}")
                            except Exception as e:
                                logger.error(f"Error removing file {file_path}: {e}")

                logger.info(f"Cleanup completed: removed {removed_count} files older than {max_age_days} days from {directory}")
                return removed_count
            except Exception as e:
                logger.error(f"Error during cleanup of directory {directory}: {e}")
                return 0

        def limit_file_count(directory, max_files=1000):
            """Ensure no more than max_files are kept in the directory (removes oldest first)."""
            if not os.path.exists(directory):
                logger.warning(f"Directory does not exist: {directory}")
                return 0

            removed_count = 0

            try:
                files = []
                for filename in os.listdir(directory):
                    file_path = os.path.join(directory, filename)
                    if os.path.isfile(file_path):
                        files.append((file_path, os.path.getmtime(file_path)))

                # Sort by modification time, oldest first.
                files.sort(key=lambda x: x[1])

                # Keep the newest max_files; everything before them is removed.
                files_to_remove = files[:-max_files] if len(files) > max_files else []

                for file_path, _ in files_to_remove:
                    try:
                        os.remove(file_path)
                        removed_count += 1
                        logger.info(f"Removed excess file: {file_path}")
                    except Exception as e:
                        logger.error(f"Error removing file {file_path}: {e}")

                logger.info(f"File count limit enforced: removed {removed_count} oldest files from {directory}, keeping max {max_files}")
                return removed_count
            except Exception as e:
                logger.error(f"Error during file count limiting in directory {directory}: {e}")
                return 0

        def perform_cleanup(directory, max_age_days=7, max_files=1000):
            """Perform both time-based and count-based cleanup."""
            time_removed = cleanup_old_files(directory, max_age_days)
            count_removed = limit_file_count(directory, max_files)
            return time_removed + count_removed

# Run an initial cleanup at import time.
try:
    removed = perform_cleanup(OUTPUT_DIR, MAX_AGE_DAYS, MAX_FILES)
    logger.info(f"Initial cleanup removed {removed} files")
except Exception as e:
    logger.error(f"Error during initial cleanup: {e}")


@app.middleware("http")
async def cleanup_middleware(request, call_next):
    """Middleware to periodically run cleanup."""
    global request_counter

    request_counter += 1

    if request_counter % CLEANUP_INTERVAL == 0:
        logger.info(f"Running scheduled cleanup after {request_counter} requests")
        try:
            removed = perform_cleanup(OUTPUT_DIR, MAX_AGE_DAYS, MAX_FILES)
            logger.info(f"Scheduled cleanup removed {removed} files")
        except Exception as e:
            logger.error(f"Error during scheduled cleanup: {e}")

    response = await call_next(request)
    return response

HF_ID_REGEX = re.compile(r"^[A-Za-z0-9._-]+/[A-Za-z0-9._-]+$")


def is_valid_hf_input(input_str: str) -> bool:
    """Checks if the input is a valid Hugging Face model ID or URL."""
    if not input_str or len(input_str) > 200:
        return False

    if input_str.startswith(("http://", "https://")):
        try:
            parsed = urlparse(input_str)
            if parsed.netloc == "huggingface.co":
                path_parts = parsed.path.strip("/").split("/")
                if len(path_parts) >= 2 and path_parts[0] and path_parts[1]:
                    if re.match(r"^[A-Za-z0-9._-]+$", path_parts[0]) and \
                       re.match(r"^[A-Za-z0-9._-]+$", path_parts[1]):
                        return True
            return False
        except Exception:
            return False
    else:
        return bool(HF_ID_REGEX.match(input_str))
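
# Illustrative checks (hedged examples, not part of the original module):
#   is_valid_hf_input("openai/whisper-large-v3")                        -> True
#   is_valid_hf_input("https://huggingface.co/openai/whisper-large-v3") -> True
#   is_valid_hf_input("not a model id")                                 -> False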


def _normalise_model_id(raw_id: str) -> str:
    """
    Accept either a validated 'owner/model' ID or a validated full URL such as
    'https://huggingface.co/owner/model'. Return 'owner/model'.
    Assumes the input has already been validated by is_valid_hf_input.
    """
    if raw_id.startswith(("http://", "https://")):
        path = urlparse(raw_id).path.lstrip("/")
        parts = path.split("/")
        return f"{parts[0]}/{parts[1]}"
    return raw_id
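
# Illustrative: _normalise_model_id("https://huggingface.co/owner/model") -> "owner/model"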


def log_sbom_generation(model_id: str):
    """Logs a successful SBOM generation event to the Hugging Face dataset."""
    if not HF_TOKEN:
        logger.warning("HF_TOKEN not set. Skipping SBOM generation logging.")
        return

    try:
        normalized_model_id_for_log = _normalise_model_id(model_id)
        log_data = {
            "timestamp": [datetime.utcnow().isoformat()],
            "event": ["generated"],
            "model_id": [normalized_model_id_for_log]
        }
        ds_new_log = Dataset.from_dict(log_data)

        try:
            # Load the existing log (if any) and append the new entry.
            existing_ds = load_dataset(HF_REPO, token=HF_TOKEN, split='train', trust_remote_code=True)
            if len(existing_ds) > 0 and set(existing_ds.column_names) == set(log_data.keys()):
                ds_to_push = concatenate_datasets([existing_ds, ds_new_log])
            elif len(existing_ds) == 0:
                logger.info(f"Dataset {HF_REPO} is empty. Pushing initial data.")
                ds_to_push = ds_new_log
            else:
                logger.warning(f"Dataset {HF_REPO} has unexpected columns {existing_ds.column_names} vs {list(log_data.keys())}. Appending new log anyway; structure might differ.")
                ds_to_push = concatenate_datasets([existing_ds, ds_new_log])
        except Exception as load_err:
            # The dataset may not exist yet; push the new log as the initial dataset.
            logger.info(f"Could not load existing dataset {HF_REPO} (may not exist yet): {load_err}. Pushing new dataset.")
            ds_to_push = ds_new_log

        ds_to_push.push_to_hub(HF_REPO, token=HF_TOKEN, private=True)
        logger.info(f"Successfully logged SBOM generation for {normalized_model_id_for_log} to {HF_REPO}")

    except Exception as e:
        logger.error(f"Failed to log SBOM generation to {HF_REPO}: {e}")
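
# Each record pushed to the usage-log dataset has this shape (illustrative values):
#   {"timestamp": "2025-01-01T00:00:00", "event": "generated", "model_id": "owner/model"}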


def get_sbom_count() -> str:
    """Retrieves the total count of generated SBOMs from the Hugging Face dataset."""
    if not HF_TOKEN:
        logger.warning("HF_TOKEN not set. Cannot retrieve SBOM count.")
        return "N/A"
    try:
        ds = load_dataset(HF_REPO, token=HF_TOKEN, split='train', trust_remote_code=True)
        count = len(ds)
        logger.info(f"Retrieved SBOM count: {count} from {HF_REPO}")
        # Format with a thousands separator for display.
        return f"{count:,}"
    except Exception as e:
        logger.error(f"Failed to retrieve SBOM count from {HF_REPO}: {e}")
        return "N/A"


@app.on_event("startup")
async def startup_event():
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    logger.info(f"Output directory ready at {OUTPUT_DIR}")
    logger.info(f"Registered routes: {[route.path for route in app.routes]}")
@app.get("/", response_class=HTMLResponse) |
|
|
async def root(request: Request): |
|
|
sbom_count = get_sbom_count() |
|
|
try: |
|
|
return templates.TemplateResponse("index.html", {"request": request, "sbom_count": sbom_count}) |
|
|
except Exception as e: |
|
|
logger.error(f"Error rendering template: {str(e)}") |
|
|
|
|
|
try: |
|
|
return templates.TemplateResponse("error.html", {"request": request, "error": f"Template rendering error: {str(e)}", "sbom_count": sbom_count}) |
|
|
except Exception as template_err: |
|
|
|
|
|
logger.error(f"Error rendering error template: {template_err}") |
|
|
raise HTTPException(status_code=500, detail=f"Template rendering error: {str(e)}") |
|
|
|
|
|
@app.get("/status", response_model=StatusResponse) |
|
|
async def get_status(): |
|
|
return StatusResponse(status="operational", version="1.0.0", generator_version="1.0.0") |
|
|
|
|
|
|
|
|
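
# Illustrative call (assuming a local server on port 8000):
#   curl http://localhost:8000/status
#   -> {"status": "operational", "version": "1.0.0", "generator_version": "1.0.0"}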


def import_utils():
    """Import utils module with fallback paths."""
    try:
        # Make the parent directory importable for the direct-import case.
        sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

        try:
            from utils import calculate_completeness_score
            logger.info("Imported utils.calculate_completeness_score directly")
            return calculate_completeness_score
        except ImportError:
            pass

        try:
            from src.aibom_generator.utils import calculate_completeness_score
            logger.info("Imported src.aibom_generator.utils.calculate_completeness_score")
            return calculate_completeness_score
        except ImportError:
            pass

        try:
            from aibom_generator.utils import calculate_completeness_score
            logger.info("Imported aibom_generator.utils.calculate_completeness_score")
            return calculate_completeness_score
        except ImportError:
            pass

        logger.warning("Could not import calculate_completeness_score, using default implementation")
        return None
    except Exception as e:
        logger.error(f"Error importing utils: {str(e)}")
        return None


# May be None if no import path succeeded; callers must handle that case.
calculate_completeness_score = import_utils()

if REGISTRY_AVAILABLE:
    logger.info("✅ API fully integrated with registry system")
else:
    logger.warning("⚠️ API using fallback mode - registry not available")


def get_tier_points(tier):
    """Get points for a field tier."""
    tier_points = {
        "critical": 4.0,
        "important": 2.0,
        "supplementary": 1.0
    }
    return tier_points.get(tier, 1.0)
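
# Illustrative: get_tier_points("critical") -> 4.0; unknown tiers default to 1.0.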


def create_registry_driven_fallback():
    """Create fallback score using registry configuration."""
    if not REGISTRY_AVAILABLE:
        return create_hardcoded_fallback()

    categories = {}
    field_checklist = {}
    max_scores = {}

    for field_name, classification in FIELD_CLASSIFICATION.items():
        category = classification["category"]
        tier = classification["tier"]

        if category not in categories:
            categories[category] = {"total": 0, "present": 0}
            max_scores[category] = 0

        categories[category]["total"] += 1
        max_scores[category] += get_tier_points(tier)

        # Star rating mirrors the field's importance tier.
        tier_stars = {"critical": "★★★", "important": "★★", "supplementary": "★"}
        field_checklist[field_name] = f"n/a {tier_stars.get(tier, '★')}"

    return {
        "total_score": 0,
        "section_scores": {cat: 0 for cat in categories.keys()},
        "max_scores": max_scores,
        "field_checklist": field_checklist,
        "category_details": categories
    }


def create_hardcoded_fallback():
    """Fallback to original hardcoded structure when registry unavailable."""
    return {
        "total_score": 0,
        "section_scores": {
            "required_fields": 0,
            "metadata": 0,
            "component_basic": 0,
            "component_model_card": 0,
            "external_references": 0
        },
        "max_scores": {
            "required_fields": 20,
            "metadata": 20,
            "component_basic": 20,
            "component_model_card": 30,
            "external_references": 10
        },
        "field_checklist": {
            "bomFormat": "n/a ★★★",
            "specVersion": "n/a ★★★",
            "serialNumber": "n/a ★★★",
            "version": "n/a ★★★",
            "name": "n/a ★★★",
            "downloadLocation": "n/a ★★★"
        }
    }


def create_comprehensive_completeness_score(aibom=None):
    """
    Create a comprehensive completeness_score object with all required attributes.
    Uses registry-driven field classification when available.
    """
    # Prefer the real scoring implementation when it was imported successfully.
    if calculate_completeness_score and aibom:
        try:
            return calculate_completeness_score(aibom, validate=True, use_best_practices=True)
        except Exception as e:
            logger.error(f"Error calculating completeness score: {str(e)}")

    if REGISTRY_AVAILABLE:
        logger.info("Using registry-driven completeness score fallback")
        return create_registry_driven_fallback()
    else:
        logger.warning("Using hardcoded completeness score fallback")
        return create_hardcoded_fallback()
@app.post("/generate", response_class=HTMLResponse) |
|
|
async def generate_form( |
|
|
request: Request, |
|
|
model_id: str = Form(...), |
|
|
include_inference: bool = Form(False), |
|
|
use_best_practices: bool = Form(True), |
|
|
g_recaptcha_response: Optional[str] = Form(None) |
|
|
): |
|
|
|
|
|
form_data = await request.form() |
|
|
logger.info(f"All form data: {dict(form_data)}") |
|
|
|
|
|
|
|
|
if not verify_recaptcha(g_recaptcha_response): |
|
|
return templates.TemplateResponse( |
|
|
"error.html", { |
|
|
"request": request, |
|
|
"error": "Security verification failed. Please try again.", |
|
|
"sbom_count": get_sbom_count() |
|
|
} |
|
|
) |
|
|
|
|
|
sbom_count = get_sbom_count() |
|
|
|
|
|
|
|
|

    # Sanitize and validate the user-supplied model ID before any use.
    sanitized_model_id = html.escape(model_id)

    if not is_valid_hf_input(sanitized_model_id):
        error_message = "Invalid input format. Please provide a valid Hugging Face model ID (e.g., 'owner/model') or a full model URL (e.g., 'https://huggingface.co/owner/model')."
        logger.warning(f"Invalid model input format received: {model_id}")
        return templates.TemplateResponse(
            "error.html", {"request": request, "error": error_message, "sbom_count": sbom_count, "model_id": sanitized_model_id}
        )

    normalized_model_id = _normalise_model_id(sanitized_model_id)

    # Verify the model exists on Hugging Face before generating.
    try:
        hf_api = HfApi()
        logger.info(f"Attempting to fetch model info for: {normalized_model_id}")
        model_info = hf_api.model_info(normalized_model_id)
        logger.info(f"Successfully fetched model info for: {normalized_model_id}")
    except RepositoryNotFoundError:
        error_message = f"Error: The provided ID \"{normalized_model_id}\" could not be found on Hugging Face or does not correspond to a model repository."
        logger.warning(f"Repository not found for ID: {normalized_model_id}")
        return templates.TemplateResponse(
            "error.html", {"request": request, "error": error_message, "sbom_count": sbom_count, "model_id": normalized_model_id}
        )
    except Exception as api_err:
        error_message = f"Error verifying model ID with Hugging Face API: {str(api_err)}"
        logger.error(f"HF API error for {normalized_model_id}: {str(api_err)}")
        return templates.TemplateResponse(
            "error.html", {"request": request, "error": error_message, "sbom_count": sbom_count, "model_id": normalized_model_id}
        )

    try:
        # Import the generator from whichever location is available.
        generator = None
        try:
            from src.aibom_generator.generator import AIBOMGenerator
            generator = AIBOMGenerator()
        except ImportError:
            try:
                from aibom_generator.generator import AIBOMGenerator
                generator = AIBOMGenerator()
            except ImportError:
                try:
                    from generator import AIBOMGenerator
                    generator = AIBOMGenerator()
                except ImportError:
                    logger.error("Could not import AIBOMGenerator from any known location")
                    raise ImportError("Could not import AIBOMGenerator from any known location")

        aibom = generator.generate_aibom(
            model_id=sanitized_model_id,
            include_inference=include_inference,
            use_best_practices=use_best_practices
        )
        enhancement_report = generator.get_enhancement_report()

        # Save the generated AIBOM and log the event.
        filename = f"{normalized_model_id.replace('/', '_')}_ai_sbom.json"
        filepath = os.path.join(OUTPUT_DIR, filename)

        with open(filepath, "w") as f:
            json.dump(aibom, f, indent=2)

        log_sbom_generation(sanitized_model_id)
        sbom_count = get_sbom_count()

        download_url = f"/output/{filename}"

        # Inline script used by the result page for downloads and tab handling.
        download_script = f"""
        <script>
        function downloadJSON() {{
            const a = document.createElement('a');
            a.href = '{download_url}';
            a.download = '{filename}';
            document.body.appendChild(a);
            a.click();
            document.body.removeChild(a);
        }}

        function switchTab(tabId) {{
            // Hide all tabs
            document.querySelectorAll('.tab-content').forEach(tab => {{
                tab.classList.remove('active');
            }});

            // Deactivate all tab buttons
            document.querySelectorAll('.aibom-tab').forEach(button => {{
                button.classList.remove('active');
            }});

            // Show the selected tab
            document.getElementById(tabId).classList.add('active');

            // Activate the clicked button
            event.currentTarget.classList.add('active');
        }}

        function toggleCollapsible(element) {{
            element.classList.toggle('active');
            var content = element.nextElementSibling;
            if (content.style.maxHeight) {{
                content.style.maxHeight = null;
                content.classList.remove('active');
            }} else {{
                content.style.maxHeight = content.scrollHeight + "px";
                content.classList.add('active');
            }}
        }}
        </script>
        """

        # Prefer the generator's own completeness score when it exposes one.
        completeness_score = None
        if hasattr(generator, 'get_completeness_score'):
            try:
                completeness_score = generator.get_completeness_score(sanitized_model_id)
                logger.info("Successfully retrieved completeness_score from generator")
            except Exception as e:
                logger.error(f"Completeness score error from generator: {str(e)}")

        if completeness_score is None or not isinstance(completeness_score, dict) or 'field_checklist' not in completeness_score:
            logger.info("Using comprehensive completeness_score with field_checklist")
            completeness_score = create_comprehensive_completeness_score(aibom)

        # Ensure the enhancement report always has the structure the template expects.
        if enhancement_report is None:
            enhancement_report = {
                "ai_enhanced": False,
                "ai_model": None,
                "original_score": {"total_score": 0, "completeness_score": 0},
                "final_score": {"total_score": 0, "completeness_score": 0},
                "improvement": 0
            }
        else:
            if "original_score" not in enhancement_report or enhancement_report["original_score"] is None:
                enhancement_report["original_score"] = {"total_score": 0, "completeness_score": 0}
            elif "completeness_score" not in enhancement_report["original_score"]:
                enhancement_report["original_score"]["completeness_score"] = enhancement_report["original_score"].get("total_score", 0)

            if "final_score" not in enhancement_report or enhancement_report["final_score"] is None:
                enhancement_report["final_score"] = {"total_score": 0, "completeness_score": 0}
            elif "completeness_score" not in enhancement_report["final_score"]:
                enhancement_report["final_score"]["completeness_score"] = enhancement_report["final_score"].get("total_score", 0)

        # Display metadata passed to the template.
        display_names = {
            "required_fields": "Required Fields",
            "metadata": "Metadata",
            "component_basic": "Component Basic Info",
            "component_model_card": "Model Card",
            "external_references": "External References"
        }

        tooltips = {
            "required_fields": "Basic required fields for a valid AIBOM",
            "metadata": "Information about the AIBOM itself",
            "component_basic": "Basic information about the AI model component",
            "component_model_card": "Detailed model card information",
            "external_references": "Links to external resources"
        }

        weights = {
            "required_fields": 20,
            "metadata": 20,
            "component_basic": 20,
            "component_model_card": 30,
            "external_references": 10
        }
print("DEBUG: Checking completeness_score for undefined values:") |
|
|
if completeness_score and 'section_scores' in completeness_score: |
|
|
for key, value in completeness_score['section_scores'].items(): |
|
|
print(f" {key}: {value} (type: {type(value)})") |
|
|
else: |
|
|
print(" No section_scores found in completeness_score") |
|
|
|
|
|
|
|
|
print("DEBUG: Template data check:") |
|
|
if completeness_score: |
|
|
print(f" completeness_score keys: {list(completeness_score.keys())}") |
|
|
if 'category_details' in completeness_score: |
|
|
print(f" category_details exists: {list(completeness_score['category_details'].keys())}") |
|
|
|
|
|
if REGISTRY_AVAILABLE: |
|
|
categories = set(classification["category"] for classification in FIELD_CLASSIFICATION.values()) |
|
|
else: |
|
|
categories = ['required_fields', 'metadata', 'component_basic', 'component_model_card', 'external_references'] |
|
|
|
|
|
for category in categories: |
|
|
if category in completeness_score['category_details']: |
|
|
details = completeness_score['category_details'][category] |
|
|
print(f" {category}: present={details.get('present_fields')}, total={details.get('total_fields')}, percentage={details.get('percentage')}") |
|
|
else: |
|
|
print(f" {category}: MISSING from category_details") |
|
|
else: |
|
|
print(" category_details: NOT FOUND in completeness_score!") |
|
|
else: |
|
|
print(" completeness_score: IS NONE!") |
|
|
|
|
|
|
|
|

        # Note: the original dict repeated display_names, tooltips and weights;
        # the duplicate keys were redundant and have been removed.
        return templates.TemplateResponse(
            "result.html",
            {
                "request": request,
                "model_id": normalized_model_id,
                "aibom": aibom,
                "enhancement_report": enhancement_report,
                "completeness_score": completeness_score,
                "download_url": download_url,
                "download_script": download_script,
                "display_names": display_names,
                "tooltips": tooltips,
                "weights": weights,
                "sbom_count": sbom_count
            }
        )

    except Exception as e:
        logger.error(f"Error generating AI SBOM: {str(e)}")
        sbom_count = get_sbom_count()
        return templates.TemplateResponse(
            "error.html", {"request": request, "error": str(e), "sbom_count": sbom_count, "model_id": normalized_model_id}
        )
@app.get("/download/{filename}") |
|
|
async def download_file(filename: str): |
|
|
""" |
|
|
Download a generated AIBOM file. |
|
|
|
|
|
This endpoint serves the generated AIBOM JSON files for download. |
|
|
""" |
|
|
file_path = os.path.join(OUTPUT_DIR, filename) |
|
|
if not os.path.exists(file_path): |
|
|
raise HTTPException(status_code=404, detail="File not found") |
|
|
|
|
|
return FileResponse( |
|
|
file_path, |
|
|
media_type="application/json", |
|
|
filename=filename |
|
|
) |
|
|
|
|
|
|
|
|


class GenerateRequest(BaseModel):
    model_id: str
    include_inference: bool = True
    use_best_practices: bool = True
    hf_token: Optional[str] = None
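
# Example request body for the JSON endpoints below (illustrative values):
#   {"model_id": "owner/model", "include_inference": true, "use_best_practices": true}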
@app.post("/api/generate") |
|
|
async def api_generate_aibom(request: GenerateRequest): |
|
|
""" |
|
|
Generate an AI SBOM for a specified Hugging Face model. |
|
|
|
|
|
This endpoint accepts JSON input and returns JSON output. |
|
|
""" |
|
|
try: |
|
|
|
|
|
sanitized_model_id = html.escape(request.model_id) |
|
|
if not is_valid_hf_input(sanitized_model_id): |
|
|
raise HTTPException(status_code=400, detail="Invalid model ID format") |
|
|
|
|
|
normalized_model_id = _normalise_model_id(sanitized_model_id) |
|
|
|
|
|
|
|
|
try: |
|
|
hf_api = HfApi() |
|
|
model_info = hf_api.model_info(normalized_model_id) |
|
|
except RepositoryNotFoundError: |
|
|
raise HTTPException(status_code=404, detail=f"Model {normalized_model_id} not found on Hugging Face") |
|
|
except Exception as api_err: |
|
|
raise HTTPException(status_code=500, detail=f"Error verifying model: {str(api_err)}") |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
generator = None |
|
|
try: |
|
|
from src.aibom_generator.generator import AIBOMGenerator |
|
|
generator = AIBOMGenerator() |
|
|
except ImportError: |
|
|
try: |
|
|
from aibom_generator.generator import AIBOMGenerator |
|
|
generator = AIBOMGenerator() |
|
|
except ImportError: |
|
|
try: |
|
|
from generator import AIBOMGenerator |
|
|
generator = AIBOMGenerator() |
|
|
except ImportError: |
|
|
raise HTTPException(status_code=500, detail="Could not import AIBOMGenerator") |
|
|
|
|
|
aibom = generator.generate_aibom( |
|
|
model_id=sanitized_model_id, |
|
|
include_inference=request.include_inference, |
|
|
use_best_practices=request.use_best_practices |
|
|
) |
|
|
enhancement_report = generator.get_enhancement_report() |
|
|
|
|
|
|
|
|
filename = f"{normalized_model_id.replace('/', '_')}_ai_sbom.json" |
|
|
filepath = os.path.join(OUTPUT_DIR, filename) |
|
|
with open(filepath, "w") as f: |
|
|
json.dump(aibom, f, indent=2) |
|
|
|
|
|
|
|
|
log_sbom_generation(sanitized_model_id) |
|
|
|
|
|
|
|
|
return { |
|
|
"aibom": aibom, |
|
|
"model_id": normalized_model_id, |
|
|
"generated_at": datetime.utcnow().isoformat() + "Z", |
|
|
"request_id": str(uuid.uuid4()), |
|
|
"download_url": f"/output/{filename}" |
|
|
} |
|
|
except HTTPException: |
|
|
raise |
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=f"Error generating AI SBOM: {str(e)}") |
|
|
except HTTPException: |
|
|
raise |
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=f"Error generating AI SBOM: {str(e)}") |
|
|
|
|
|
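
# Illustrative call (assuming a local server on port 8000):
#   curl -X POST http://localhost:8000/api/generate \
#        -H "Content-Type: application/json" \
#        -d '{"model_id": "owner/model"}'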
@app.post("/api/generate-with-report") |
|
|
async def api_generate_with_report(request: GenerateRequest): |
|
|
""" |
|
|
Generate an AI SBOM with a completeness report. |
|
|
This endpoint accepts JSON input and returns JSON output with completeness score. |
|
|
""" |
|
|
try: |
|
|
|
|
|
sanitized_model_id = html.escape(request.model_id) |
|
|
if not is_valid_hf_input(sanitized_model_id): |
|
|
raise HTTPException(status_code=400, detail="Invalid model ID format") |
|
|
|
|
|
normalized_model_id = _normalise_model_id(sanitized_model_id) |
|
|
|
|
|
|
|
|
try: |
|
|
hf_api = HfApi() |
|
|
model_info = hf_api.model_info(normalized_model_id) |
|
|
except RepositoryNotFoundError: |
|
|
raise HTTPException(status_code=404, detail=f"Model {normalized_model_id} not found on Hugging Face") |
|
|
except Exception as api_err: |
|
|
raise HTTPException(status_code=500, detail=f"Error verifying model: {str(api_err)}") |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
generator = None |
|
|
try: |
|
|
from src.aibom_generator.generator import AIBOMGenerator |
|
|
generator = AIBOMGenerator() |
|
|
except ImportError: |
|
|
try: |
|
|
from aibom_generator.generator import AIBOMGenerator |
|
|
generator = AIBOMGenerator() |
|
|
except ImportError: |
|
|
try: |
|
|
from generator import AIBOMGenerator |
|
|
generator = AIBOMGenerator() |
|
|
except ImportError: |
|
|
raise HTTPException(status_code=500, detail="Could not import AIBOMGenerator") |
|
|
|
|
|
aibom = generator.generate_aibom( |
|
|
model_id=sanitized_model_id, |
|
|
include_inference=request.include_inference, |
|
|
use_best_practices=request.use_best_practices |
|
|
) |
|
|
|
|
|
|
|
|
try: |
|
|
completeness_score = calculate_completeness_score(aibom, validate=True, use_best_practices=True) |
|
|
except Exception as e: |
|
|
logger.error(f"Failed completeness scoring for {normalized_model_id}: {str(e)}") |
|
|
raise HTTPException(status_code=500, detail=f"Error calculating score: {str(e)}") |
|
|
|
|
|
|
|
|
for section, score in completeness_score["section_scores"].items(): |
|
|
if isinstance(score, float) and not score.is_integer(): |
|
|
completeness_score["section_scores"][section] = round(score, 1) |
|
|
|
|
|
|
|
|
if "field_checklist" in completeness_score: |
|
|
machine_parseable_checklist = {} |
|
|
for field, value in completeness_score["field_checklist"].items(): |
|
|
|
|
|
present = "present" if "✔" in value else "missing" |
|
|
|
|
|
|
|
|
importance = completeness_score["field_tiers"].get(field, "unknown") |
|
|
|
|
|
|
|
|
machine_parseable_checklist[field] = { |
|
|
"status": present, |
|
|
"importance": importance |
|
|
} |
|
|
|
|
|
|
|
|
completeness_score["field_checklist"] = machine_parseable_checklist |
|
|
|
|
|
|
|
|
completeness_score.pop("field_tiers", None) |
|
|
|
|
|
|
|
|
filename = f"{normalized_model_id.replace('/', '_')}_ai_sbom.json" |
|
|
filepath = os.path.join(OUTPUT_DIR, filename) |
|
|
with open(filepath, "w") as f: |
|
|
json.dump(aibom, f, indent=2) |
|
|
|
|
|
|
|
|
log_sbom_generation(sanitized_model_id) |
|
|
|
|
|
|
|
|
return { |
|
|
"aibom": aibom, |
|
|
"model_id": normalized_model_id, |
|
|
"generated_at": datetime.utcnow().isoformat() + "Z", |
|
|
"request_id": str(uuid.uuid4()), |
|
|
"download_url": f"/output/{filename}", |
|
|
"completeness_score": completeness_score |
|
|
} |
|
|
except HTTPException: |
|
|
raise |
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=f"Error generating AI SBOM: {str(e)}") |
|
|
except HTTPException: |
|
|
raise |
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=f"Error generating AI SBOM: {str(e)}") |
|
|
|
|
|
|
|
|
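
# Illustrative call (assuming a local server on port 8000):
#   curl -X POST http://localhost:8000/api/generate-with-report \
#        -H "Content-Type: application/json" \
#        -d '{"model_id": "owner/model"}'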
@app.get("/api/models/{model_id:path}/score" ) |
|
|
async def get_model_score( |
|
|
model_id: str, |
|
|
hf_token: Optional[str] = None, |
|
|
use_best_practices: bool = True |
|
|
): |
|
|
""" |
|
|
Get the completeness score for a model without generating a full AIBOM. |
|
|
""" |
|
|
try: |
|
|
|
|
|
sanitized_model_id = html.escape(model_id) |
|
|
if not is_valid_hf_input(sanitized_model_id): |
|
|
raise HTTPException(status_code=400, detail="Invalid model ID format") |
|
|
|
|
|
normalized_model_id = _normalise_model_id(sanitized_model_id) |
|
|
|
|
|
|
|
|
try: |
|
|
hf_api = HfApi(token=hf_token) |
|
|
model_info = hf_api.model_info(normalized_model_id) |
|
|
except RepositoryNotFoundError: |
|
|
raise HTTPException(status_code=404, detail=f"Model {normalized_model_id} not found on Hugging Face") |
|
|
except Exception as api_err: |
|
|
raise HTTPException(status_code=500, detail=f"Error verifying model: {str(api_err)}") |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
generator = None |
|
|
try: |
|
|
from src.aibom_generator.generator import AIBOMGenerator |
|
|
generator = AIBOMGenerator(hf_token=hf_token) |
|
|
except ImportError: |
|
|
try: |
|
|
from aibom_generator.generator import AIBOMGenerator |
|
|
generator = AIBOMGenerator(hf_token=hf_token) |
|
|
except ImportError: |
|
|
try: |
|
|
from generator import AIBOMGenerator |
|
|
generator = AIBOMGenerator(hf_token=hf_token) |
|
|
except ImportError: |
|
|
raise HTTPException(status_code=500, detail="Could not import AIBOMGenerator") |
|
|
|
|
|
|
|
|
aibom = generator.generate_aibom( |
|
|
model_id=sanitized_model_id, |
|
|
include_inference=False, |
|
|
use_best_practices=use_best_practices |
|
|
) |
|
|
|
|
|
|
|
|
score = calculate_completeness_score(aibom, validate=True, use_best_practices=use_best_practices) |
|
|
|
|
|
|
|
|
log_sbom_generation(normalized_model_id) |
|
|
|
|
|
|
|
|
for section, value in score["section_scores"].items(): |
|
|
if isinstance(value, float) and not value.is_integer(): |
|
|
score["section_scores"][section] = round(float(value), 1) if value is not None and value != "Undefined" else 0.0 |
|
|
|
|
|
|
|
|
return { |
|
|
"total_score": score["total_score"], |
|
|
"section_scores": score["section_scores"], |
|
|
"max_scores": score["max_scores"] |
|
|
} |
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=f"Error calculating model score: {str(e)}") |
|
|
except HTTPException: |
|
|
raise |
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") |
|
|
|
|
|
|
|
|
|
|
|
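
# Illustrative call (assuming a local server on port 8000):
#   curl "http://localhost:8000/api/models/owner/model/score?use_best_practices=true"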


class BatchRequest(BaseModel):
    model_ids: List[str]
    include_inference: bool = True
    use_best_practices: bool = True
    hf_token: Optional[str] = None


# In-memory store of batch jobs (lost on restart).
batch_jobs = {}
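
# Example request body (illustrative values):
#   {"model_ids": ["owner/model-a", "owner/model-b"], "include_inference": true}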
@app.post("/api/batch") |
|
|
async def batch_generate(request: BatchRequest): |
|
|
""" |
|
|
Start a batch job to generate AIBOMs for multiple models. |
|
|
""" |
|
|
try: |
|
|
|
|
|
valid_model_ids = [] |
|
|
for model_id in request.model_ids: |
|
|
sanitized_id = html.escape(model_id) |
|
|
if is_valid_hf_input(sanitized_id): |
|
|
valid_model_ids.append(sanitized_id) |
|
|
else: |
|
|
logger.warning(f"Skipping invalid model ID: {model_id}") |
|
|
|
|
|
if not valid_model_ids: |
|
|
raise HTTPException(status_code=400, detail="No valid model IDs provided") |
|
|
|
|
|
|
|
|
job_id = str(uuid.uuid4()) |
|
|
created_at = datetime.utcnow() |
|
|
|
|
|
|
|
|
batch_jobs[job_id] = { |
|
|
"job_id": job_id, |
|
|
"status": "queued", |
|
|
"model_ids": valid_model_ids, |
|
|
"created_at": created_at.isoformat() + "Z", |
|
|
"completed": 0, |
|
|
"total": len(valid_model_ids), |
|
|
"results": {} |
|
|
} |
|
|
|
|
|
|
|
|
batch_jobs[job_id]["status"] = "processing" |
|
|
|
|
|
return { |
|
|
"job_id": job_id, |
|
|
"status": "queued", |
|
|
"model_ids": valid_model_ids, |
|
|
"created_at": created_at.isoformat() + "Z" |
|
|
} |
|
|
except HTTPException: |
|
|
raise |
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=f"Error creating batch job: {str(e)}") |
|
|
|
|
|
@app.get("/api/batch/{job_id}") |
|
|
async def get_batch_status(job_id: str): |
|
|
""" |
|
|
Check the status of a batch job. |
|
|
""" |
|
|
if job_id not in batch_jobs: |
|
|
raise HTTPException(status_code=404, detail=f"Batch job {job_id} not found") |
|
|
|
|
|
return batch_jobs[job_id] |
|
|
|
|
|
|
|
|
|
|
|
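
# Illustrative call (assuming a local server on port 8000; <job_id> is a placeholder):
#   curl http://localhost:8000/api/batch/<job_id>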


@app.get("/api/registry/status")
async def get_registry_status():
    """Get current registry configuration status for debugging."""
    if REGISTRY_AVAILABLE:
        categories = {}
        for field_name, classification in FIELD_CLASSIFICATION.items():
            category = classification["category"]
            if category not in categories:
                categories[category] = 0
            categories[category] += 1

        return {
            "registry_available": True,
            "total_fields": len(FIELD_CLASSIFICATION),
            "categories": list(categories.keys()),
            "field_count_by_category": categories,
            "registry_manager_loaded": REGISTRY_MANAGER is not None
        }
    else:
        return {
            "registry_available": False,
            "fallback_mode": True,
            "message": "Using hardcoded field definitions",
            "total_fields": 6
        }


if __name__ == "__main__":
    # Keep this block last so every route above is registered before the
    # (blocking) server starts.
    import uvicorn

    if not HF_TOKEN:
        print("Warning: HF_TOKEN environment variable not set. SBOM count will show N/A and logging will be skipped.")
    uvicorn.run(app, host="0.0.0.0", port=8000)