"""Dress-code compliance API.

Downloads a video from Firebase Storage, samples frames from it, runs the
``valentinafeve/yolos-fashionpedia`` object detector on each sampled frame and
reports the fraction of frames that contain at least one item of formal
workplace attire. Each result is stored in the Firestore ``results`` collection.
"""

# NOTE(review): nothing visible in this file uses fastai; the star import is
# kept first (as originally) so the explicit imports below win any name clash.
from fastai.vision.all import *  # noqa: F401,F403
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware  # noqa: F401
import uvicorn
import logging
import tempfile
from pathlib import Path
import firebase_admin
from firebase_admin import credentials, firestore, storage  # noqa: F401
from pydantic import BaseModel
import torch
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from PIL import Image, ImageDraw, ImageFont  # noqa: F401
import cv2
import random

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# NOTE(review): the original file referenced `db` and `bucket` without ever
# initialising Firebase. Called with no arguments, initialize_app() uses Google
# Application Default Credentials and reads its options (including
# "storageBucket") from the FIREBASE_CONFIG environment variable. Replace with
# credentials.Certificate(...) plus {"storageBucket": ...} if this service is
# deployed that way.
try:
    firebase_admin.get_app()
except ValueError:  # raised when the default app has not been initialised yet
    firebase_admin.initialize_app()
db = firestore.client()
bucket = storage.bucket()

app = FastAPI()

# Load model and processor once at import time; every request reuses them.
processor = AutoImageProcessor.from_pretrained("valentinafeve/yolos-fashionpedia")
model = AutoModelForObjectDetection.from_pretrained("valentinafeve/yolos-fashionpedia")

# Fashionpedia categories, indexed by the detector's integer label id.
# NOTE(review): assumes the model's label ids follow this order; verify
# against model.config.id2label.
FASHION_CATEGORIES = [
    'shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket',
    'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress', 'jumpsuit', 'cape',
    'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove',
    'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe',
    'bag, wallet', 'scarf', 'umbrella', 'hood', 'collar', 'lapel', 'epaulette',
    'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow',
    'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel',
]

# Category names that make a frame count as compliant (built once, not per call).
FORMAL_WORKPLACE_ATTIRE = frozenset({
    "shirt, blouse", "jacket", "tie", "coat", "sweater", "cardigan",
})

# Extensions (lower-cased, without the dot) that process_file accepts.
VIDEO_EXTENSIONS = frozenset({'mp4', 'avi', 'mov', 'wmv'})


class FileProcess(BaseModel):
    """Request body for ``POST /process``."""

    # Object path inside the Firebase Storage bucket bound to `bucket`.
    file_path: str


def detect_fashion(image, threshold=0.5):
    """Run the Fashionpedia detector on a PIL image.

    Args:
        image: PIL image; ``image.size`` is read as ``(width, height)``.
        threshold: detections must score strictly above this to be kept.
            The default of 0.5 matches the original hard-coded cut-off.

    Returns:
        A list of ``(category_name, score, box)`` tuples, where ``score`` is a
        Python float and ``box`` is the list form of the box returned by
        ``processor.post_process_object_detection`` (absolute pixel
        coordinates scaled to the image size).
    """
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: do not build an autograd graph for every frame.
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL reports (width, height); the post-processor wants (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]
    return [
        (FASHION_CATEGORIES[int(label)], score.item(), box.tolist())
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"])
    ]


def check_dress_code(detected_items):
    """Return True if any detected item's category is formal workplace attire.

    Args:
        detected_items: iterable of tuples whose first element is a category
            name, as produced by ``detect_fashion``.
    """
    return any(item[0] in FORMAL_WORKPLACE_ATTIRE for item in detected_items)


@app.post("/process")
def process_file(file_data: FileProcess):
    """Download a video from Firebase Storage, score it and store the result.

    Declared with plain ``def`` (not ``async def``) on purpose: the download
    and model inference block, so FastAPI runs this handler in its threadpool
    instead of stalling the event loop.

    Returns:
        A dict with ``message`` and ``result``. If the Firestore write fails,
        the dict also has an ``error`` field and the request still succeeds.

    Raises:
        HTTPException: 400 for an unsupported file extension; 500 if
            downloading or processing the video fails.
    """
    logger.info(f"Processing file from Firebase Storage: {file_data.file_path}")

    extension = file_data.file_path.split('.')[-1]
    # Reject unsupported types up front, before downloading anything.
    if extension.lower() not in VIDEO_EXTENSIONS:
        raise HTTPException(status_code=400, detail="Unsupported file type")

    tmp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=f".{extension}") as tmp_file:
            tmp_file_path = Path(tmp_file.name)
        # Download after the handle is closed so reopening the path also works
        # on platforms that forbid opening a file that is already open.
        bucket.blob(file_data.file_path).download_to_filename(str(tmp_file_path))
        logger.info(f"File downloaded temporarily at: {tmp_file_path}")

        average_compliance, _frame_results = process_video(str(tmp_file_path))
        result = {"type": "video", "data": {"result": average_compliance}}
        logger.info(f"Processing complete. Result: {result}")
    except Exception as e:
        logger.exception(f"Error processing file: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}") from e
    finally:
        # Best-effort cleanup; runs even when the download itself failed.
        if tmp_file_path is not None:
            try:
                tmp_file_path.unlink(missing_ok=True)
            except OSError as cleanup_error:
                logger.warning(f"Could not delete temporary file {tmp_file_path}: {cleanup_error}")

    # Store the result in Firestore; a failure here is reported, not fatal.
    try:
        db.collection('results').add(result)
    except Exception as e:
        logger.error(f"Failed to store result in Firebase: {str(e)}")
        return {"message": "File processed successfully, but failed to store in Firebase",
                "result": result, "error": str(e)}
    return {"message": "File processed successfully", "result": result}


def process_video(video_path, num_frames=10):
    """Check dress-code compliance on randomly sampled frames of a video.

    Up to ``num_frames`` distinct frame indices are sampled without
    replacement and visited in ascending order. Frames that fail to seek or
    decode are skipped.

    Args:
        video_path: filesystem path that ``cv2.VideoCapture`` can open.
        num_frames: maximum number of frames to sample.

    Returns:
        ``(average_compliance, compliance_results)``. ``compliance_results``
        holds one bool per successfully read frame. ``average_compliance`` is
        the fraction of those frames that are compliant.

    Raises:
        ValueError: if the video cannot be opened or no sampled frame could be
            read (previously this surfaced as a ZeroDivisionError).
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")
        # CAP_PROP_FRAME_COUNT can be 0 or negative for some containers/streams.
        total_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)
        sample_size = max(0, min(num_frames, total_frames))
        frame_indices = sorted(random.sample(range(total_frames), sample_size))

        compliance_results = []
        for frame_index in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ret, frame = cap.read()
            if not ret:
                continue
            # OpenCV decodes to BGR; PIL and the detector expect RGB.
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            compliance_results.append(check_dress_code(detect_fashion(image)))
    finally:
        cap.release()

    if not compliance_results:
        raise ValueError(f"No frames could be read from video: {video_path}")
    average_compliance = sum(compliance_results) / len(compliance_results)
    return average_compliance, compliance_results


if __name__ == "__main__":
    logger.info("Starting the Dress Code Compliance API")
    uvicorn.run(app, host="0.0.0.0", port=8000)