from fastai.vision.all import *
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import logging
import tempfile
from pathlib import Path
import firebase_admin
from firebase_admin import credentials, firestore, storage
from pydantic import BaseModel
import torch
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from PIL import Image, ImageDraw, ImageFont
import cv2
import random
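
# --- Application setup (reconstructed sketch) ---
# The code below references `logger`, `app`, `db`, `bucket`, and `FileProcess`
# without defining them, so this block fills in a minimal setup. The service
# account key path and storage bucket name are placeholders, not values from
# the original listing.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize the Firebase Admin SDK for Firestore and Cloud Storage access
cred = credentials.Certificate("serviceAccountKey.json")  # placeholder path
firebase_admin.initialize_app(cred, {"storageBucket": "your-project.appspot.com"})  # placeholder bucket
db = firestore.client()
bucket = storage.bucket()


class FileProcess(BaseModel):
    # Path of the uploaded file inside the Firebase Storage bucket
    file_path: str
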
# Load model and processor
processor = AutoImageProcessor.from_pretrained("valentinafeve/yolos-fashionpedia")
model = AutoModelForObjectDetection.from_pretrained("valentinafeve/yolos-fashionpedia")
# Fashionpedia categories (index order must match the model's label ids)
FASHION_CATEGORIES = [
    'shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt',
    'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove',
    'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf', 'umbrella', 'hood',
    'collar', 'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow',
    'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel'
]

def detect_fashion(image):
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)
    # Convert outputs (bounding boxes and class logits) to COCO API format
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)[0]
    detected_items = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        if score > 0.5:  # Adjust this confidence threshold as needed
            detected_items.append((FASHION_CATEGORIES[label], score.item(), box.tolist()))
    return detected_items

def check_dress_code(detected_items):
    formal_workplace_attire = {
        "shirt, blouse", "jacket", "tie", "coat", "sweater", "cardigan"
    }
    return any(item[0] in formal_workplace_attire for item in detected_items)

@app.post("/process")  # NOTE: the route decorator is missing in the original listing; this path is an assumption
async def process_file(file_data: FileProcess):
    logger.info(f"Processing file from Firebase Storage: {file_data.file_path}")
    try:
        # Get the file from Firebase Storage
        blob = bucket.blob(file_data.file_path)
        # Create a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_data.file_path.split('.')[-1]}") as tmp_file:
            blob.download_to_filename(tmp_file.name)
            tmp_file_path = Path(tmp_file.name)
        logger.info(f"File downloaded temporarily at: {tmp_file_path}")
        file_type = file_data.file_path.split('.')[-1].lower()
        try:
            if file_type in ['mp4', 'avi', 'mov', 'wmv']:
                average_compliance, frame_results = process_video(str(tmp_file_path))
                result = {"type": "video", "data": {"result": average_compliance}}
            else:
                raise HTTPException(status_code=400, detail="Unsupported file type")
            logger.info(f"Processing complete. Result: {result}")
            # Store result in Firebase
            try:
                doc_ref = db.collection('results').add(result)
                return {"message": "File processed successfully", "result": result}
            except Exception as e:
                logger.error(f"Failed to store result in Firebase: {str(e)}")
                return {"message": "File processed successfully, but failed to store in Firebase",
                        "result": result, "error": str(e)}
        finally:
            # Clean up the temporary file
            tmp_file_path.unlink()
    except HTTPException:
        # Re-raise client errors (e.g. unsupported file type) instead of wrapping them in a 500
        raise
    except Exception as e:
        logger.error(f"Error processing file: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}")

def process_video(video_path, num_frames=10):
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Sample a handful of random frames rather than scoring every frame
    frame_indices = sorted(random.sample(range(total_frames), min(num_frames, total_frames)))
    compliance_results = []
    for frame_index in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ret, frame = cap.read()
        if ret:
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            detected_items = detect_fashion(image)
            is_compliant = check_dress_code(detected_items)
            compliance_results.append(is_compliant)
    cap.release()
    if not compliance_results:
        # No readable frames: avoid a division by zero
        return 0.0, []
    average_compliance = sum(compliance_results) / len(compliance_results)
    return average_compliance, compliance_results

if __name__ == "__main__":
    logger.info("Starting the dress code compliance API")
    uvicorn.run(app, host="0.0.0.0", port=8000)
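
# Example request (sketch): the "/process" route and JSON body follow the
# FileProcess model assumed above; the storage path is a placeholder.
#
#   curl -X POST http://localhost:8000/process \
#        -H "Content-Type: application/json" \
#        -d '{"file_path": "uploads/sample_video.mp4"}'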