Spaces:
Build error
Build error
import io
import logging
import os
import sqlite3
from contextlib import closing
from typing import cast

import pandas as pd
import torch
import torch.nn as nn  # was `torch.nn.nn`, which is not a module and fails at import time
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms, models
# Module-wide logging: INFO and above, one logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application object that every route below is attached to.
app = FastAPI(
    title="Pet Breed Prediction API",
    version="1.0.0",
)

# CORS configuration for the Docker / Hugging Face deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for Hugging Face deployment
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=False,  # Must be False if using "*"
)
# Explicit OPTIONS handler for preflight requests.
# NOTE(review): the route decorator was absent, so this handler was never
# registered. The "/predict" path is inferred from the handler name and the
# "POST, OPTIONS" header below - confirm against the frontend.
@app.options("/predict")
async def options_predict():
    """Answer a CORS preflight request with permissive CORS headers."""
    return JSONResponse(
        content={"message": "OK"},
        headers={
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "POST, OPTIONS",
            "Access-Control-Allow-Headers": "*",
        },
    )
# --- Configuration and Initialization ---


def _first_existing_path(primary: str, fallback: str) -> str:
    """Return *primary* when it exists on disk, otherwise *fallback*.

    *fallback* is returned without checking that it exists; callers that
    need the file still have to verify it themselves.
    """
    return primary if os.path.exists(primary) else fallback


# Docker container layout first, local-development layout as the fallback.
MODEL_PATH = _first_existing_path("/app/Backend/best_model.pth", "./best_model.pth")
DATABASE_PATH = _first_existing_path("/app/Backend/petdatabase.db", "./petdatabase.db")
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Record the resolved runtime configuration once at import time.
logger.info(f"Using device: {device}")
logger.info(f"Model path: {MODEL_PATH}")
logger.info(f"Database path: {DATABASE_PATH}")
# Class labels. predict() indexes this list with the model's output index, so
# the order must not change.
# NOTE(review): 'Golden Retreiver' is misspelled, but the string is also used
# for the database lookup in predict(); it presumably mirrors the training
# label, so it is kept as-is - verify before correcting.
class_names = [
    'Abyssinian', 'Alaskan Malamute', 'American Bobtail', 'American Shorthair', 'American bulldog',
    'American Pit Bull Terrier', 'Basset Hound', 'Beagle', 'Bengal', 'Birman', 'Bombay', 'Boxer',
    'British Shorthair', 'Bulldog', 'Calico', 'Chihuahua', 'Dachshund', 'Egyptian Mau',
    'English Cocker Spaniel', 'English Setter', 'German Shepherd', 'German Shorthaired Pointer',
    'Golden Retreiver', 'Great Pyrenees', 'Havanese', 'Husky', 'Japanese Chin', 'Keeshond',
    'Labrador Retriever', 'Leonberger', 'Maine Coon', 'Miniature Pinscher', 'Munchkin',
    'Newfoundland', 'Norwegian Forest', 'Ocicat', 'Persian', 'Pomeranian', 'Poodle', 'Pug',
    'Ragdoll', 'Rottweiler', 'Russian Blue', 'Saint Bernard', 'Samoyed', 'Scottish Fold',
    'Scottish Terrier', 'Shiba Inu', 'Siamese', 'Sphynx', 'Staffordshire Bull Terrier',
    'Tortoiseshell', 'Tuxedo', 'Wheaten Terrier', 'Yorkshire Terrier',
]

# Width of the classifier's output layer. Derived from the label list so the
# two can never drift apart (evaluates to 55, the previously hard-coded value).
NUM_CLASSES = len(class_names)

# The loaded network; stays None until load_model() succeeds.
model = None
# Breed metadata table; starts empty and is replaced by load_breed_descriptions().
breed_descriptions_df = pd.DataFrame()

# Image transformations (must match your training).
# NOTE(review): the mean/std values look like the usual ImageNet statistics -
# confirm against the training pipeline.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def load_model():
    """Build the EfficientNet-B0 classifier and load its weights from MODEL_PATH.

    The network is assembled and loaded in a local variable and only published
    to the module-level ``model`` once every step has succeeded. Previously the
    global was assigned before the weights were read, so a failed load left an
    untrained network in place that passed the ``model is None`` guard.

    Raises:
        RuntimeError: if the weights file is missing or cannot be loaded. The
            original exception is attached as ``__cause__``.
    """
    global model
    try:
        logger.info(f"Loading model from {MODEL_PATH}")
        # Check if model file exists
        if not os.path.exists(MODEL_PATH):
            logger.error(f"Model file not found at {MODEL_PATH}")
            raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")

        # Calling the constructor without arguments yields randomly
        # initialised weights on both the legacy (`pretrained=`) and the
        # current (`weights=`) torchvision APIs, and avoids the deprecation
        # warning that `pretrained=False` triggers on newer releases.
        network = models.efficientnet_b0()
        num_ftrs: int = network.classifier[1].in_features  # type: ignore
        network.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(num_ftrs, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, NUM_CLASSES),
        )

        try:
            state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
        except TypeError:
            # Fallback for older PyTorch versions without `weights_only`.
            # SECURITY: this path unpickles arbitrary objects; only load
            # checkpoint files from a trusted source.
            state_dict = torch.load(MODEL_PATH, map_location=device)

        network.load_state_dict(state_dict)
        network.to(device)
        network.eval()
    except FileNotFoundError as e:
        logger.error(f"Model file not found: {e}")
        raise RuntimeError(f"Model file not found: {e}") from e
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise RuntimeError(f"Error loading model: {e}") from e

    # Publish only a fully loaded, eval-mode network.
    model = network
    logger.info(f"Model loaded successfully from {MODEL_PATH}")
def load_breed_descriptions():
    """Load the cat and dog breed tables from SQLite into ``breed_descriptions_df``.

    Reads the ``catbreeds`` and ``dogbreeds`` tables from ``DATABASE_PATH``,
    adds a ``type`` column (``'cat'`` / ``'dog'``) to each non-empty table and
    stores their row-wise concatenation in the module-level
    ``breed_descriptions_df``. A table that cannot be read is logged and
    skipped.

    Never raises: on any failure, or when the database file is missing, the
    global is set to an empty DataFrame.
    """
    global breed_descriptions_df
    try:
        logger.info(f"Loading breed descriptions from database: {DATABASE_PATH}")
        if not os.path.exists(DATABASE_PATH):
            logger.error(f"Database file not found at {DATABASE_PATH}")
            breed_descriptions_df = pd.DataFrame()
            return

        frames = []
        # closing() releases the connection even when a read raises; sqlite3's
        # own context manager only commits or rolls back, it does not close.
        with closing(sqlite3.connect(DATABASE_PATH)) as conn:
            for table, pet_type in (("catbreeds", "cat"), ("dogbreeds", "dog")):
                try:
                    frame = pd.read_sql_query(f"SELECT * FROM {table}", conn)
                except Exception as e:
                    logger.warning(f"Failed to load {table} table: {e}")
                    continue
                logger.info(f"Loaded {len(frame)} records from {table}")
                if not frame.empty:
                    # Tag the rows so cats and dogs stay distinguishable
                    # after the tables are stacked.
                    frame['type'] = pet_type
                    frames.append(frame)

        # pd.concat raises on an empty list, so fall back to an empty frame.
        if frames:
            breed_descriptions_df = pd.concat(frames, ignore_index=True, sort=False)
        else:
            breed_descriptions_df = pd.DataFrame()

        if breed_descriptions_df.empty:
            logger.warning("No breed descriptions loaded from database")
        else:
            logger.info(f"Total breed descriptions loaded: {len(breed_descriptions_df)}")
    except Exception as e:
        logger.error(f"Error loading breed descriptions from database: {e}")
        breed_descriptions_df = pd.DataFrame()
# NOTE(review): the registration decorator was absent, so this hook never ran
# and the model was never loaded. `on_event("startup")` is inferred from the
# function name; newer FastAPI releases prefer a lifespan handler - confirm
# which one the deployment expects.
@app.on_event("startup")
async def startup_event():
    """Load the model and the breed descriptions when the server starts.

    Each loader is attempted independently, so a missing model file no longer
    prevents the breed descriptions from being loaded. Failures are logged and
    swallowed so the server still starts and the health endpoints can report
    what is missing.
    """
    logger.info("Starting up Pet Breed Prediction API...")
    logger.info(f"Working directory: {os.getcwd()}")
    logger.info(f"Available files: {os.listdir('.')}")

    all_loaded = True
    for loader in (load_model, load_breed_descriptions):
        try:
            loader()
        except Exception as e:
            all_loaded = False
            logger.error(f"Startup failed: {e}")
            # Don't raise here to allow the server to start even if model/db loading fails

    if all_loaded:
        logger.info("Startup complete!")
def _to_json_safe(value):
    """Convert one DataFrame cell into a value a JSON encoder can emit."""
    if hasattr(value, 'item'):
        value = value.item()  # numpy scalar -> plain Python scalar
    # NaN / NA / NaT are not valid JSON, and columns present in only one of the
    # concatenated breed tables are NaN for rows of the other, so map missing
    # values to None.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return None
    return value


def get_breed_description(breed_name: str) -> dict:
    """Return the stored metadata row for *breed_name* as a plain dict.

    Matching against the ``breed_name`` column ignores case and surrounding
    whitespace; the first matching row wins.

    Returns:
        A dict of column name -> JSON-serialisable value (missing values are
        ``None``), or ``{}`` when no descriptions are loaded, the table has no
        ``breed_name`` column, or the breed is not found.
    """
    breed_col = 'breed_name'
    if breed_descriptions_df.empty or breed_col not in breed_descriptions_df.columns:
        return {}

    wanted = breed_name.strip().lower()
    normalized = breed_descriptions_df[breed_col].astype(str).str.strip().str.lower()
    matches = breed_descriptions_df[normalized == wanted]
    if matches.empty:
        return {}

    record = matches.iloc[0].to_dict()
    return {key: _to_json_safe(value) for key, value in record.items()}
# Breed-table columns copied verbatim into every prediction entry, in
# response order.
_BREED_INFO_FIELDS = (
    'height_male_min',
    'height_male_max',
    'height_female_min',
    'height_female_max',
    'weight_male_min',
    'weight_male_max',
    'weight_female_min',
    'weight_female_max',
    'life_expectancy_min',
    'life_expectancy_max',
    'characteristics',
    'exercise_needs',
    'grooming_requirements',
    'health_considerations',
    'diet_nutrition',
)


def _build_prediction(breed: str, probability: float) -> dict:
    """Assemble one response entry for *breed* and its softmax *probability*."""
    breed_info = get_breed_description(breed)
    entry = {
        "breed": breed,
        "confidence": f"{probability*100:.2f}%",
        "breed_name": breed_info.get('breed_name', breed),
    }
    # Fields without stored metadata are reported as None.
    entry.update((field, breed_info.get(field)) for field in _BREED_INFO_FIELDS)
    return entry


# NOTE(review): the route decorator was absent, so this handler was never
# reachable. POST "/predict" is inferred from the handler name and the
# preflight handler's "POST, OPTIONS" header - confirm against the frontend.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded pet image and return the most likely breeds.

    Returns:
        JSON with ``predictions`` (up to three entries, best first, each with
        the breed, a percentage ``confidence`` string and any stored breed
        metadata) and ``confidence_message`` (a warning string when the best
        probability is below 0.70, otherwise ``None``).

    Raises:
        HTTPException: 500 when the model is not loaded or inference fails;
            400 when the upload is not an image or cannot be read or decoded.
    """
    logger.info(f"Received prediction request for file: {file.filename}")

    # Check if model is loaded
    if model is None:
        logger.error("Model not loaded")
        raise HTTPException(status_code=500, detail="Model not loaded. Please check server logs and try again.")

    # Validate content type
    if not file.content_type or not file.content_type.startswith("image/"):
        logger.warning(f"Invalid content type: {file.content_type}")
        raise HTTPException(status_code=400, detail="File must be an image.")

    # Read file contents
    try:
        contents = await file.read()
    except Exception as e:
        logger.error(f"Error reading file: {e}")
        raise HTTPException(status_code=400, detail="Error reading uploaded file.") from e
    logger.info(f"File size: {len(contents)} bytes")

    # Decode the image
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        raise HTTPException(status_code=400, detail="Invalid image file. Please provide a valid image (e.g., JPG, PNG).") from e
    logger.info(f"Image size: {image.size}")

    # Make prediction
    # NOTE(review): inference runs synchronously inside this coroutine and so
    # blocks the event loop for its duration; consider a thread pool if
    # concurrent requests matter.
    try:
        input_tensor = cast(torch.Tensor, transform(image)).unsqueeze(0).to(device)
        logger.info(f"Input tensor shape: {input_tensor.shape}")
        with torch.no_grad():
            outputs = model(input_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1).squeeze(dim=0)

        # topk raises when k exceeds the number of classes, so cap it.
        top_k = min(3, probabilities.numel())
        top_probs, top_indices = torch.topk(probabilities, top_k)

        results = [
            _build_prediction(class_names[int(index)], probability)
            for probability, index in zip(top_probs.tolist(), top_indices.tolist())
        ]

        confidence_message = None
        if top_probs[0].item() < 0.70:
            confidence_message = "**Message: Make sure the image is clear! Low confidence detected.**"

        logger.info(f"Prediction successful. Top breed: {results[0]['breed']} ({results[0]['confidence']})")
        return JSONResponse(content={
            "predictions": results,
            "confidence_message": confidence_message
        })
    except Exception as e:
        # logger.exception keeps the traceback in the server log.
        logger.exception(f"Error during prediction: {e}")
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}") from e
# NOTE(review): the route decorator was absent, so this handler was never
# reachable. GET "/" is inferred from the handler name - confirm.
@app.get("/")
async def root():
    """Return a welcome payload with basic service status."""
    return {
        "message": "Welcome to the Pet Breed Prediction API!",
        "status": "running",
        "model_loaded": model is not None,
        "device": str(device),
        "version": "1.0.0",
        "docker_deployment": True,
    }
# NOTE(review): the route decorator was absent, so this handler was never
# reachable. GET "/health" is inferred from the handler name - confirm against
# whatever probes the service.
@app.get("/health")
async def health_check():
    """Report whether the model, database and runtime are in place.

    ``status`` is always ``"healthy"``; callers should inspect
    ``model_loaded`` and ``breed_descriptions_loaded`` for actual readiness.
    """
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "device": str(device),
        "breed_descriptions_loaded": not breed_descriptions_df.empty,
        "database_path": DATABASE_PATH,
        "database_exists": os.path.exists(DATABASE_PATH),
        "model_path": MODEL_PATH,
        "model_exists": os.path.exists(MODEL_PATH),
        "working_directory": os.getcwd(),
        "pytorch_version": torch.__version__,
    }
# Endpoint to check database structure.
# NOTE(review): the route decorator was absent, so this handler was never
# reachable. The GET path below is a guess based on the handler name - confirm.
@app.get("/database-info")
async def database_info():
    """Describe the SQLite database: its tables and each table's columns.

    Returns:
        A dict with the database path, table names, per-table
        ``PRAGMA table_info`` records (or an ``"Error: ..."`` string for a
        table that could not be inspected) and how many breed descriptions are
        currently loaded. On failure, a dict with a single ``error`` key.
    """
    try:
        if not os.path.exists(DATABASE_PATH):
            return {"error": f"Database file not found at {DATABASE_PATH}"}

        # closing() releases the connection even when a query raises.
        with closing(sqlite3.connect(DATABASE_PATH)) as conn:
            # Get table names
            tables_query = "SELECT name FROM sqlite_master WHERE type='table'"
            tables = pd.read_sql_query(tables_query, conn)['name'].tolist()

            # Get column info for each table
            table_info = {}
            for table in tables:
                try:
                    # Quote the identifier so names containing spaces or
                    # quotes do not break (or alter) the statement.
                    quoted = '"' + str(table).replace('"', '""') + '"'
                    columns_df = pd.read_sql_query(f"PRAGMA table_info({quoted})", conn)
                    table_info[table] = columns_df.to_dict('records')
                except Exception as e:
                    table_info[table] = f"Error: {e}"

        return {
            "database_path": DATABASE_PATH,
            "tables": tables,
            "table_info": table_info,
            "breed_descriptions_loaded": not breed_descriptions_df.empty,
            "total_records": len(breed_descriptions_df) if not breed_descriptions_df.empty else 0,
        }
    except Exception as e:
        return {"error": f"Error reading database: {e}"}
# Endpoint for debugging the file system.
# NOTE(review): the route decorator was absent, so this handler was never
# reachable. The GET path below is a guess based on the handler name - confirm.
# SECURITY: this exposes directory listings to any caller; consider disabling
# or protecting it outside of debugging sessions.
@app.get("/debug/files")
async def debug_files():
    """Debug endpoint to check file system structure.

    Returns:
        A dict with the working directory, its contents, and for each of a
        fixed set of deployment directories either its listing or the string
        ``"Directory not found"``. On failure, a dict with a single ``error``
        key.
    """
    try:
        current_dir = os.getcwd()
        files_info = {
            "current_directory": current_dir,
            "files_in_current_dir": os.listdir(current_dir) if os.path.exists(current_dir) else [],
        }

        # Check specific directories
        for dir_path in ("/app", "/app/Backend", "/app/Frontend", "/app/Asset"):
            key = f"files_in_{dir_path.replace('/', '_')}"
            # isdir (not exists): a regular file at one of these paths would
            # make listdir raise and turn the whole response into an error.
            if os.path.isdir(dir_path):
                files_info[key] = os.listdir(dir_path)
            else:
                files_info[key] = "Directory not found"
        return files_info
    except Exception as e:
        return {"error": f"Error checking files: {e}"}
if __name__ == "__main__":
    import uvicorn

    logger.info("Starting server...")
    # Server configuration for the Docker deployment.
    # The app object is passed directly instead of the "main:app" import
    # string: the string only resolves when this file is named main.py and
    # makes uvicorn import the module a second time. Passing the instance is
    # equivalent here because reload is disabled and no workers are requested.
    uvicorn.run(
        app,
        host="127.0.0.1",  # Bind to localhost since nginx will proxy
        port=8000,  # Use port 8000 for backend (nginx listens on 7860)
        reload=False,  # Disable reload in production
        log_level="info",
        access_log=True,
    )