# Author: AhmedIbrahim007
# Last commit: 275f227 "Update app.py" (verified)
from fastai.vision.all import *
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import logging
import tempfile
from pathlib import Path
import firebase_admin
from firebase_admin import credentials, firestore, storage
from pydantic import BaseModel
import torch
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from PIL import Image, ImageDraw, ImageFont
import cv2
import random
# Hugging Face Hub id of the detection checkpoint; the processor and the model must
# come from the same checkpoint so preprocessing matches what the model was trained on.
_MODEL_ID = "valentinafeve/yolos-fashionpedia"

# Loaded once at import time and shared by every request.
processor = AutoImageProcessor.from_pretrained(_MODEL_ID)
model = AutoModelForObjectDetection.from_pretrained(_MODEL_ID)
# Fashionpedia category names, indexed by the detector's label id: the name for
# label id i is FASHION_CATEGORIES[i]. The order therefore has to match the
# checkpoint's label ids (presumably Fashionpedia's category order -- verify
# against model.config.id2label if the checkpoint changes).
FASHION_CATEGORIES = [
    # positions 0-12: garments
    'shirt, blouse',
    'top, t-shirt, sweatshirt',
    'sweater',
    'cardigan',
    'jacket',
    'vest',
    'pants',
    'shorts',
    'skirt',
    'coat',
    'dress',
    'jumpsuit',
    'cape',
    # positions 13-26: accessories
    'glasses',
    'hat',
    'headband, head covering, hair accessory',
    'tie',
    'glove',
    'watch',
    'belt',
    'leg warmer',
    'tights, stockings',
    'sock',
    'shoe',
    'bag, wallet',
    'scarf',
    'umbrella',
    # positions 27-33: garment parts
    'hood',
    'collar',
    'lapel',
    'epaulette',
    'sleeve',
    'pocket',
    'neckline',
    # positions 34-45: closures and decorations
    'buckle',
    'zipper',
    'applique',
    'bead',
    'bow',
    'flower',
    'fringe',
    'ribbon',
    'rivet',
    'ruffle',
    'sequin',
    'tassel',
]
def detect_fashion(image, *, score_threshold=0.5):
    """Run the fashion detector on one image and return its confident detections.

    Args:
        image: a PIL image (``image.size`` is read as ``(width, height)``).
        score_threshold: detections must score strictly above this to be kept.
            Defaults to 0.5, the cut-off this function has always used.

    Returns:
        A list of ``(category_name, score, box)`` tuples. ``category_name`` comes
        from ``FASHION_CATEGORIES``, ``score`` is a Python float and ``box`` is the
        post-processor's box as a list of 4 floats, scaled to ``image``'s size.
    """
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: disabling autograd avoids building a gradient graph per call.
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL gives (width, height); the post-processor expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    # Let the post-processor apply the cut-off directly, instead of using a low
    # threshold and re-filtering afterwards.
    results = processor.post_process_object_detection(
        outputs, threshold=score_threshold, target_sizes=target_sizes
    )[0]
    return [
        (FASHION_CATEGORIES[label.item()], score.item(), box.tolist())
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"])
    ]
# Detected category names (as spelled in FASHION_CATEGORIES) that satisfy the
# formal-workplace dress code. Built once at import rather than on every call.
FORMAL_WORKPLACE_ATTIRE = frozenset({
    "shirt, blouse",
    "jacket",
    "tie",
    "coat",
    "sweater",
    "cardigan",
})


def check_dress_code(detected_items, *, formal_attire=FORMAL_WORKPLACE_ATTIRE):
    """Return True if any detected item is formal workplace attire.

    Args:
        detected_items: iterable of detections whose first element is the
            category name, e.g. the ``(name, score, box)`` tuples from
            ``detect_fashion``.
        formal_attire: collection of category names that count as compliant;
            defaults to ``FORMAL_WORKPLACE_ATTIRE``.

    Returns:
        True if at least one item's name is in ``formal_attire``. An empty
        input returns False.
    """
    return any(item[0] in formal_attire for item in detected_items)
# Lower-case file extensions (without the dot) that process_file can analyse.
_VIDEO_EXTENSIONS = frozenset({"mp4", "avi", "mov", "wmv"})


def _analyse_stored_video(file_path, extension):
    """Download ``file_path`` from ``bucket`` to a temp file, score it with process_video, then delete it.

    Blocking (network download + model inference), so async callers should run it
    in a worker thread.

    Returns:
        ``{"type": "video", "data": {"result": <average compliance>}}``.
    """
    blob = bucket.blob(file_path)
    # Create and close the temp file first so it can be reopened by name on any OS
    # (an open NamedTemporaryFile cannot be reopened on Windows). delete=False
    # because we remove it ourselves in the finally below.
    with tempfile.NamedTemporaryFile(delete=False, suffix=f".{extension}") as tmp_file:
        tmp_file_path = Path(tmp_file.name)
    try:
        blob.download_to_filename(str(tmp_file_path))
        logger.info(f"File downloaded temporarily at: {tmp_file_path}")
        average_compliance, _per_frame = process_video(str(tmp_file_path))
    finally:
        # Runs whether or not the download/analysis succeeded, so failures no longer leak temp files.
        tmp_file_path.unlink(missing_ok=True)
    return {"type": "video", "data": {"result": average_compliance}}


@app.post("/process")
async def process_file(file_data: FileProcess):
    """Analyse a video stored in Firebase Storage and record the result.

    Args:
        file_data: request body; only ``file_data.file_path`` (the object's path
            in ``bucket``) is read.

    Returns:
        A dict with ``"message"`` and ``"result"``. If writing the result via
        ``db.collection('results')`` fails, the request still succeeds and an
        ``"error"`` string is added.

    Raises:
        HTTPException: 400 for unsupported file extensions (checked before any
            download); 500 if downloading or analysing the file fails.
    """
    from fastapi.concurrency import run_in_threadpool

    logger.info(f"Processing file from Firebase Storage: {file_data.file_path}")
    extension = Path(file_data.file_path).suffix[1:]
    # Raised outside the try below so clients receive this 400 instead of it being
    # caught by `except Exception` and re-wrapped as a 500.
    if extension.lower() not in _VIDEO_EXTENSIONS:
        raise HTTPException(status_code=400, detail="Unsupported file type")

    try:
        # Download + inference block; run them off the event loop.
        result = await run_in_threadpool(_analyse_stored_video, file_data.file_path, extension)
    except Exception as e:
        logger.exception(f"Error processing file: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}") from e
    logger.info(f"Processing complete. Result: {result}")

    # Store result in Firebase. This is best-effort: a failure is reported in the
    # response, but the request does not fail.
    try:
        await run_in_threadpool(lambda: db.collection('results').add(result))
    except Exception as e:
        logger.error(f"Failed to store result in Firebase: {str(e)}")
        return {"message": "File processed successfully, but failed to store in Firebase", "result": result,
                "error": str(e)}
    return {"message": "File processed successfully", "result": result}
def process_video(video_path, num_frames=10, *, rng=None):
    """Estimate dress-code compliance from a random sample of frames in a video.

    Args:
        video_path: path to a video file that OpenCV can open.
        num_frames: number of distinct frames to sample. Fewer are used if the
            video reports fewer frames.
        rng: optional object with a ``sample`` method (e.g. ``random.Random(seed)``)
            for reproducible sampling. Defaults to the module-level ``random``.

    Returns:
        ``(average_compliance, compliance_results)``. ``compliance_results`` holds
        one bool per sampled frame that was read successfully.
        ``average_compliance`` is the fraction of those frames that passed
        ``check_dress_code``.

    Raises:
        ValueError: if the video cannot be opened or no sampled frame could be
            read. Previously this case crashed with ZeroDivisionError.
    """
    sampler = random if rng is None else rng
    compliance_results = []
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")
        # The frame count is backend-reported metadata and may be 0 or negative
        # when unknown; clamp it so random.sample never sees a negative size.
        total_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)
        frame_indices = sorted(sampler.sample(range(total_frames), min(num_frames, total_frames)))
        for frame_index in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ret, frame = cap.read()
            if not ret:
                continue  # seek/decode failed for this frame; skip it
            # OpenCV decodes to BGR; the detector is given an RGB PIL image.
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            compliance_results.append(check_dress_code(detect_fashion(image)))
    finally:
        # Release the capture even if detection raises.
        cap.release()

    if not compliance_results:
        raise ValueError(f"No frames could be read from video: {video_path}")
    average_compliance = sum(compliance_results) / len(compliance_results)
    return average_compliance, compliance_results
if __name__ == "__main__":
    # Direct-run entry point: serve the FastAPI app with uvicorn on all interfaces, port 8000.
    # (The old log text said "Face Emotion Recognition API", a copy-paste leftover;
    # this service scores dress-code compliance.)
    logger.info("Starting the Dress Code Compliance API")
    uvicorn.run(app, host="0.0.0.0", port=8000)