from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
import numpy as np
import cv2
from mtcnn import MTCNN
from collections import deque, Counter
from PIL import Image

LOCAL_MODEL_PATH = "sota_model"

class HFPredictor:
    def __init__(self, smoothing_window=10, confidence_threshold=0.3):
        print(f"[PREDICTOR INFO] Loading model from local path: {LOCAL_MODEL_PATH}...")
        self.processor = AutoImageProcessor.from_pretrained(LOCAL_MODEL_PATH)
        self.model = AutoModelForImageClassification.from_pretrained(LOCAL_MODEL_PATH)
        self.face_detector = MTCNN()
        self.classes = list(self.model.config.id2label.values())
        self.confidence_threshold = confidence_threshold
        self.recent_predictions = deque(maxlen=smoothing_window)
        self.stable_prediction = "---"
        print("[PREDICTOR INFO] Predictor initialized successfully.")
    def process_frame(self, frame, smooth=False):
        """
        Processes a single frame. This method is used for ALL predictions
        (live, image, and video) to ensure consistency. Pass smooth=True
        for live streams so the displayed label is smoothed over recent
        frames; leave it False for static images and standalone video frames.
        """
        if frame is None:
            return frame, {}
        annotated_frame = frame.copy()
        all_probabilities = {}
        # MTCNN expects RGB input; frames are assumed to arrive as RGB arrays
        # (as Gradio supplies them), not OpenCV's default BGR.
        faces = self.face_detector.detect_faces(frame)
        for face in faces:
            x, y, width, height = face['box']
            # MTCNN can return slightly negative coordinates; clamp to the frame.
            x, y = max(0, x), max(0, y)
            face_roi = frame[y:y+height, x:x+width]
            if face_roi.size > 0:
                pil_image = Image.fromarray(face_roi)
                inputs = self.processor(images=pil_image, return_tensors="pt")
                with torch.no_grad():
                    logits = self.model(**inputs).logits
                probs = torch.nn.functional.softmax(logits, dim=-1)
                predictions = probs[0].numpy()
                pred_index = np.argmax(predictions)
                confidence = predictions[pred_index]
                # For the bounding-box text, decide which label to show: the
                # live feed wants smoothed predictions, while static images
                # and video frames want the direct one. The caller signals
                # which context it is through the smooth flag.
                if smooth:
                    # Live stream: keep a rolling window of confident
                    # predictions and display the most common label to
                    # avoid flicker between frames.
                    if confidence > self.confidence_threshold:
                        self.recent_predictions.append(pred_index)
                    if self.recent_predictions:
                        most_common_pred = Counter(self.recent_predictions).most_common(1)[0][0]
                        display_emotion = self.classes[most_common_pred]
                    else:
                        display_emotion = self.classes[pred_index]
                else:
                    # Static image or single video frame: use the direct
                    # prediction. Callers can clear self.recent_predictions
                    # between live sessions so a new stream starts without
                    # stale history.
                    display_emotion = self.classes[pred_index]
                text = f"{display_emotion} ({confidence*100:.1f}%)"
                GREEN = (0, 255, 0)
                BLACK = (0, 0, 0)
                FONT = cv2.FONT_HERSHEY_SIMPLEX
                (text_width, text_height), baseline = cv2.getTextSize(text, FONT, 0.8, 2)
                cv2.rectangle(annotated_frame, (x, y - text_height - baseline - 10),
                              (x + text_width + 10, y), GREEN, cv2.FILLED)
                cv2.putText(annotated_frame, text, (x + 5, y - 5), FONT, 0.8, BLACK, 2)
                cv2.rectangle(annotated_frame, (x, y), (x + width, y + height), GREEN, 3)
                # With multiple faces, this keeps the distribution of the
                # last face processed.
                all_probabilities = {self.classes[i]: float(predictions[i]) for i in range(len(self.classes))}
        return annotated_frame, all_probabilities
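

# --- Usage sketch (not part of the original Space; a minimal illustration) ---
# Shows how the class is expected to be driven: a one-off call for a static
# image (smooth=False) and a loop for a live stream (smooth=True). The file
# path and webcam index below are hypothetical placeholders.
if __name__ == "__main__":
    predictor = HFPredictor()

    # Static image: direct, unsmoothed prediction. OpenCV loads BGR, so
    # convert to the RGB layout process_frame assumes.
    image_bgr = cv2.imread("example_face.jpg")  # hypothetical path
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    annotated, probabilities = predictor.process_frame(image_rgb, smooth=False)
    print(probabilities)

    # Live stream: smoothed labels; clear history so the session starts fresh.
    predictor.recent_predictions.clear()
    capture = cv2.VideoCapture(0)
    while capture.isOpened():
        ok, frame_bgr = capture.read()
        if not ok:
            break
        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        annotated, _ = predictor.process_frame(frame_rgb, smooth=True)
        # Convert back to BGR for display with OpenCV.
        cv2.imshow("prediction", cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
    capture.release()
    cv2.destroyAllWindows()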