# Scraped page header (Hugging Face Spaces) - Space status: Runtime error
from fastai.vision.all import * | |
from fastapi import FastAPI, HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
import uvicorn | |
import logging | |
import tempfile | |
from pathlib import Path | |
import firebase_admin | |
from firebase_admin import credentials, firestore, storage | |
from pydantic import BaseModel | |
# Load the exported fastai Learner used for attire classification.
# load_learner unpickles the file, so only ever point it at a trusted model.
learn = load_learner(fname='model.pkl')
# Category labels in sorted order, each mapped to its position in that order.
searches = sorted(['formal', 'casual', 'athletic'])
values = list(range(len(searches)))
class_dict = {label: idx for idx, label in enumerate(searches)}
# Configure root logging at DEBUG with "<time> - <logger> - <level> - <message>"
# lines, then fetch the logger for this module.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(name=__name__)
# Initialize Firebase. A failure is logged rather than raised so the API can
# still start; in that case the handles below remain None instead of being
# undefined, which would otherwise surface later as an opaque NameError.
firebase_app = None
db = None
bucket = None
try:
    cred = credentials.Certificate("serviceAccountKey.json")
    firebase_app = firebase_admin.initialize_app(cred, {
        'storageBucket': 'future-forge-60d3f.appspot.com'
    })
    db = firestore.client()
    bucket = storage.bucket(app=firebase_app)
    logger.info("Firebase initialized successfully")
except Exception as e:
    # logger.exception also records the traceback, which str(e) alone loses.
    logger.exception(f"Failed to initialize Firebase: {str(e)}")
app = FastAPI()

# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True means any
# site may issue credentialed requests; Starlette's docs say wildcards should not
# be combined with credentials - restrict origins if cookies/auth are ever used.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Request body for the file-processing handler.
class FileProcess(BaseModel):
    """Identify the stored object to process.

    Attributes:
        file_path: Object path inside the configured Firebase Storage bucket;
            its extension decides whether the file is handled as an image.
    """

    file_path: str
async def process_file(file_data: FileProcess):
    """Download an image from Firebase Storage, classify it, and store the result.

    Args:
        file_data: Request body whose ``file_path`` names the object in the
            Firebase Storage bucket.

    Returns:
        dict: ``{"message": ..., "result": ...}`` on success. If the result
        could not be written to Firestore, the message says so and an
        ``"error"`` key carries the failure text.

    Raises:
        HTTPException: 400 for an unsupported file extension; 500 when the
            download or classification fails.
    """
    # NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
    # handler, so FastAPI never registers it - confirm the intended path and add it.
    # NOTE(review): the blob download and model inference below are blocking calls
    # inside an ``async def`` and stall the event loop while they run.
    logger.info(f"Processing file from Firebase Storage: {file_data.file_path}")

    # Take the extension from the final path component only: splitting the whole
    # path on '.' breaks for extension-less names and dotted folder names.
    extension = Path(file_data.file_path).suffix
    file_type = extension.lstrip('.').lower()
    if file_type not in ['jpg', 'jpeg', 'png', 'bmp']:
        # Reject before downloading anything, and outside the generic handler
        # below so the client sees a 400 rather than a 500.
        raise HTTPException(status_code=400, detail="Unsupported file type")

    tmp_file_path = None
    try:
        blob = bucket.blob(file_data.file_path)
        # Reserve a temp file name and close the handle before downloading into it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp_file:
            tmp_file_path = Path(tmp_file.name)
        blob.download_to_filename(str(tmp_file_path))
        logger.info(f"File downloaded temporarily at: {tmp_file_path}")

        output = process_video(str(tmp_file_path))
        result = {"type": "image", "data": {"result": output}}
        logger.info(f"Processing complete. Result: {result}")
    except Exception as e:
        logger.exception(f"Error processing file: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}") from e
    finally:
        # Clean up the temporary file even when the download itself failed.
        if tmp_file_path is not None and tmp_file_path.exists():
            tmp_file_path.unlink()

    # Persisting the result is best-effort: a Firestore failure is reported
    # in the response instead of failing the whole request.
    try:
        db.collection('results').add(result)
    except Exception as e:
        logger.exception(f"Failed to store result in Firebase: {str(e)}")
        return {"message": "File processed successfully, but failed to store in Firebase", "result": result,
                "error": str(e)}
    return {"message": "File processed successfully", "result": result}
def process_video(video_path):
    """Classify one image as formal or informal attire.

    Despite the name, the input is a single still image (it is opened with
    ``PILImage.create``), not a video.

    Args:
        video_path: Filesystem path of the image to classify.

    Returns:
        dict: ``{'formal': p}`` when the model's top class is ``'formal'``,
        where ``p`` is that class's probability; otherwise
        ``{'informal': p}`` where ``p`` is the summed probability of every
        non-formal class.
    """
    img = PILImage.create(video_path)
    classification, _, probs = learn.predict(img)

    # ``probs`` is indexed in the order of the model's own vocabulary. Reading the
    # labels from ``learn.dls.vocab`` keeps the mapping correct even if that order
    # ever diverges from the hard-coded ``class_dict``.
    labels = list(learn.dls.vocab)
    confidences = {label: float(probs[i]) for i, label in enumerate(labels)}

    # Collapse every non-formal class into a single 'informal' bucket.
    if classification != 'formal':
        informal_confidence = sum(p for label, p in confidences.items() if label != 'formal')
        return {'informal': informal_confidence}
    return {'formal': confidences['formal']}
if __name__ == "__main__":
    # Serve the API on all interfaces, port 8000, when run as a script.
    # The previous message ("Face Emotion Recognition API") was a leftover that
    # does not describe this attire classifier.
    logger.info("Starting the formal/informal attire classification API")
    uvicorn.run(app, host="0.0.0.0", port=8000)