import torch
import torch.nn as nn
from transformers import ViTModel, ViTConfig
from torchvision import transforms
import cv2
import numpy as np
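# Illustrative label mapping for the default num_emotions=7 (assumed
# FER-2013-style classes; adjust the order to match your training labels).
EMOTION_LABELS = ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']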
class VisionEmotionModel(nn.Module):
    """
    Vision Transformer for facial emotion recognition.

    The ViT backbone is kept frozen; only the emotion and confidence heads are
    trained, e.g. on FER-2013 or AffectNet face crops.
    """

    def __init__(self, num_emotions=7, pretrained=True):
        super().__init__()
        self.num_emotions = num_emotions

        # Load the ViT-Base backbone; pretrained weights are recommended since
        # the backbone is frozen below.
        if pretrained:
            self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        else:
            config = ViTConfig()
            self.vit = ViTModel(config)

        # Freeze the backbone: only the heads defined below receive gradients.
        for param in self.vit.parameters():
            param.requires_grad = False

        # Emotion classification head on top of the [CLS] token embedding.
        self.emotion_classifier = nn.Sequential(
            nn.Linear(self.vit.config.hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_emotions)
        )

        # Confidence head: predicts a scalar in [0, 1] per face.
        self.confidence_head = nn.Sequential(
            nn.Linear(self.vit.config.hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

        # Preprocessing for RGB uint8 face crops: resize to the ViT input size
        # and apply ImageNet normalization.
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def forward(self, x):
        """
        x: a preprocessed batch of images (B, C, H, W), or a list of RGB uint8
           face crops as returned by detect_faces.
        Returns: emotion_logits (B, num_emotions), confidence (B,)
        """
        if isinstance(x, list):
            # Preprocess each crop and stack into a batch on the model's device.
            batch = torch.stack([self.transform(img) for img in x])
            batch = batch.to(next(self.parameters()).device)
        else:
            batch = x

        outputs = self.vit(pixel_values=batch)
        # Use the [CLS] token embedding as the per-face representation.
        cls_token = outputs.last_hidden_state[:, 0, :]

        emotion_logits = self.emotion_classifier(cls_token)
        confidence = self.confidence_head(cls_token)

        # squeeze(-1) keeps the batch dimension even when there is a single face.
        return emotion_logits, confidence.squeeze(-1)

    def detect_faces(self, frame):
        """
        Detect faces in a BGR video frame using OpenCV's Haar cascade.
        Returns a list of RGB face crops.
        """
        # Build the cascade once and reuse it on subsequent calls.
        if not hasattr(self, '_face_cascade'):
            self._face_cascade = cv2.CascadeClassifier(
                cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        faces = self._face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=4)

        face_crops = []
        for (x, y, w, h) in faces:
            face = frame[y:y+h, x:x+w]
            if face.size > 0:
                # OpenCV frames are BGR; convert crops to RGB for the ViT pipeline.
                face_crops.append(cv2.cvtColor(face, cv2.COLOR_BGR2RGB))

        return face_crops

    def extract_features(self, faces):
        """
        Extract emotion logits and confidences from a list of face crops.
        Returns (None, None) when no faces were detected.
        """
        if not faces:
            return None, None

        # Inference only; call self.eval() beforehand to disable dropout.
        with torch.no_grad():
            emotion_logits, confidence = self.forward(faces)

        return emotion_logits, confidence
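

# Minimal usage sketch (illustrative): detect faces in one image and report the
# predicted emotion per face. 'group_photo.jpg' is a hypothetical path, and the
# classification heads are untrained here, so load a checkpoint for real output.
if __name__ == '__main__':
    model = VisionEmotionModel(num_emotions=7, pretrained=True)
    model.eval()

    frame = cv2.imread('group_photo.jpg')  # BGR image, as detect_faces expects
    if frame is not None:
        faces = model.detect_faces(frame)
        emotion_logits, confidence = model.extract_features(faces)
        if emotion_logits is not None:
            probs = torch.softmax(emotion_logits, dim=-1)
            for i, p in enumerate(probs):
                label = EMOTION_LABELS[int(p.argmax())]
                print(f"face {i}: {label} (confidence {confidence[i].item():.2f})")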