import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib.figure import Figure
from PIL import Image
from torchvision import models, transforms

# Class mapping for the RAF-DB dataset (7 classes); keys are the model's output indices.
class_mapping = {
    0: "Surprise",
    1: "Fear",
    2: "Disgust",
    3: "Happiness",
    4: "Sadness",
    5: "Anger",
    6: "Neutral",
}

# Channel statistics shared by the input transform and the de-normalisation used for display,
# so the two can never drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Transformations for inference (same as test transform).
transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


# Feature Extraction Backbone
class IR50(nn.Module):
    """Truncated ResNet-50 feature extractor: stem, layer1, layer2, then a 1x1 stride-2 projection.

    For a 112x112 input the output is (B, 256, 7, 7): conv1 + maxpool reduce 112 -> 28,
    layer2 reduces 28 -> 14 (512 channels) and the projection reduces 14 -> 7 (256 channels).

    Args:
        pretrained: Initialise the ResNet-50 with ImageNet weights before truncating it.
            Pass False when a full checkpoint is loaded afterwards (avoids a download).
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()
        resnet = models.resnet50(weights="IMAGENET1K_V1" if pretrained else None)
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.downsample = nn.Conv2d(512, 256, 1, stride=2)
        self.bn_downsample = nn.BatchNorm2d(256, eps=1e-5)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.downsample(x)
        x = self.bn_downsample(x)
        return x


# HLA Stream
class HLA(nn.Module):
    """Spatial attention (max of two 1x1 branches) followed by squeeze-excite channel attention.

    Output has the same shape as the input feature map.
    """

    def __init__(self, in_channels=256, reduction=4):
        super().__init__()
        reduced_channels = in_channels // reduction
        self.spatial_branch1 = nn.Conv2d(in_channels, reduced_channels, 1)
        self.spatial_branch2 = nn.Conv2d(in_channels, reduced_channels, 1)
        self.sigmoid = nn.Sigmoid()
        self.channel_restore = nn.Conv2d(reduced_channels, in_channels, 1)
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False),
            nn.Sigmoid(),
        )
        self.bn = nn.BatchNorm2d(in_channels, eps=1e-5)

    def forward(self, x):
        b1 = self.spatial_branch1(x)
        b2 = self.spatial_branch2(x)
        spatial_attn = self.sigmoid(torch.max(b1, b2))
        # NOTE: the 1x1 restore conv runs after the sigmoid, so the spatial weights applied
        # below are not confined to [0, 1]. Kept as-is to match the trained checkpoint.
        spatial_attn = self.channel_restore(spatial_attn)
        spatial_out = x * spatial_attn
        channel_attn = self.channel_attention(spatial_out)
        out = spatial_out * channel_attn
        out = self.bn(out)
        return out


# ViT Stream
class ViT(nn.Module):
    """Transformer over the 7x7 feature map (one token per cell) plus a CLS token.

    Returns the normalised CLS embedding, shape (B, embed_dim).
    """

    def __init__(self, in_channels=256, patch_size=1, embed_dim=768, num_layers=8, num_heads=12):
        # 8 layers as in the 82.93% version
        super().__init__()
        self.patch_embed = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Assumes a 7x7 input feature map (what IR50 produces for 112x112 images).
        num_patches = (7 // patch_size) * (7 // patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        # NOTE(review): these layers use the default batch_first=False, but forward() feeds them
        # (B, tokens, embed_dim). Attention therefore runs across the batch dimension at each
        # token position; with B=1 (single-image inference) every token only attends to itself.
        # Left unchanged because the checkpoint was presumably trained with this layout, and
        # switching to batch_first=True would change the model's outputs.
        self.transformer = nn.ModuleList([
            nn.TransformerEncoderLayer(
                embed_dim, num_heads, dim_feedforward=1536, activation="gelu"
            )
            for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_dim)
        self.bn = nn.BatchNorm1d(embed_dim, eps=1e-5)

        # Initialize weights
        nn.init.xavier_uniform_(self.patch_embed.weight)
        nn.init.zeros_(self.patch_embed.bias)
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)  # (B, C, H, W) -> (B, H*W, C)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        for layer in self.transformer:
            x = layer(x)
        x = x[:, 0]  # CLS token
        x = self.ln(x)
        x = self.bn(x)
        return x


# Intensity Stream
class IntensityStream(nn.Module):
    """Sobel-gradient texture features plus single-head self-attention context.

    Returns ``(out, grad_magnitude, variance)``:
      * ``out``: (B, 256) = pooled texture (128) concatenated with attention context (128).
      * ``grad_magnitude``: per-channel Sobel gradient magnitude, same shape as the input.
      * ``variance``: squared deviation from the channel mean, averaged over channels and
        flattened to (B, H*W).
    """

    def __init__(self, in_channels=256):
        super().__init__()
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
        # Depthwise convs (groups=in_channels) apply the same Sobel kernel to every channel.
        self.sobel_x = nn.Conv2d(
            in_channels, in_channels, 3, padding=1, bias=False, groups=in_channels
        )
        self.sobel_y = nn.Conv2d(
            in_channels, in_channels, 3, padding=1, bias=False, groups=in_channels
        )
        with torch.no_grad():
            # repeat() on the (3, 3) kernel yields (in_channels, 1, 3, 3), the depthwise shape.
            self.sobel_x.weight.copy_(sobel_x.repeat(in_channels, 1, 1, 1))
            self.sobel_y.weight.copy_(sobel_y.repeat(in_channels, 1, 1, 1))
        self.conv = nn.Conv2d(in_channels, 128, 3, padding=1)
        self.bn = nn.BatchNorm2d(128, eps=1e-5)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.attention = nn.MultiheadAttention(embed_dim=128, num_heads=1)

        # Initialize weights
        nn.init.xavier_uniform_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        gx = self.sobel_x(x)
        gy = self.sobel_y(x)
        grad_magnitude = torch.sqrt(gx**2 + gy**2 + 1e-8)  # epsilon keeps sqrt differentiable at 0
        variance = ((x - x.mean(dim=1, keepdim=True)) ** 2).mean(dim=1).flatten(1)
        cnn_out = F.relu(self.conv(grad_magnitude))
        cnn_out = self.bn(cnn_out)
        texture_out = self.pool(cnn_out).squeeze(-1).squeeze(-1)
        # (B, C, H, W) -> (H*W, B, C): sequence-first layout expected by MultiheadAttention.
        attn_in = cnn_out.flatten(2).permute(2, 0, 1)
        attn_in = attn_in / (attn_in.norm(dim=-1, keepdim=True) + 1e-8)
        attn_out, _ = self.attention(attn_in, attn_in, attn_in)
        context_out = attn_out.mean(dim=0)
        out = torch.cat([texture_out, context_out], dim=1)
        return out, grad_magnitude, variance


# Full Model (Single-Label Prediction)
class TripleStreamHLAViT(nn.Module):
    """IR50 backbone feeding HLA, ViT and intensity streams, fused into a linear classifier.

    Args:
        num_classes: Number of output logits.
        pretrained_backbone: Forwarded to ``IR50(pretrained=...)``.

    ``forward`` returns ``(logits, hla_out, vit_out, grad_magnitude, variance)``.
    """

    def __init__(self, num_classes=7, pretrained_backbone: bool = True):
        super().__init__()
        self.backbone = IR50(pretrained=pretrained_backbone)
        self.hla = HLA()
        self.vit = ViT()
        self.intensity = IntensityStream()
        self.fc_hla = nn.Linear(256 * 7 * 7, 768)
        self.fc_intensity = nn.Linear(256, 768)
        self.fusion_fc = nn.Linear(768 * 3, 512)
        self.bn_fusion = nn.BatchNorm1d(512, eps=1e-5)
        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(512, num_classes)

        # Initialize weights
        for layer in (self.fc_hla, self.fc_intensity, self.fusion_fc, self.classifier):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        features = self.backbone(x)
        hla_out = self.hla(features)
        vit_out = self.vit(features)
        intensity_out, grad_magnitude, variance = self.intensity(features)
        # flatten(1) keeps the batch dimension intact; a wrong spatial size now fails in fc_hla
        # instead of being silently regrouped into a different batch size.
        hla_flat = self.fc_hla(hla_out.flatten(1))
        intensity_flat = self.fc_intensity(intensity_out)
        fused = torch.cat([hla_flat, vit_out, intensity_flat], dim=1)
        fused = F.relu(self.fusion_fc(fused))
        fused = self.bn_fusion(fused)
        fused = self.dropout(fused)
        logits = self.classifier(fused)
        return logits, hla_out, vit_out, grad_magnitude, variance


# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ImageNet weights are unnecessary here: the strict load_state_dict below must supply every
# parameter and buffer of the model, so skip the download.
model = TripleStreamHLAViT(num_classes=7, pretrained_backbone=False).to(device)
model_path = "triple_stream_model_rafdb.pth"  # Ensure this file is in the Hugging Face Space repository

try:
    # Map the weights straight onto the device the model lives on.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    print("Model loaded successfully")
except Exception as e:
    print(f"Error loading model: {e}")
    raise


# Inference and Visualization Function
def predict_emotion(image):
    """Classify the facial emotion in ``image`` and render an HLA attention heatmap.

    Args:
        image: A PIL image (as supplied by ``gr.Image(type="pil")``) or a NumPy array.

    Returns:
        ``(label, probabilities, visualization_path)``: the predicted class name, a dict mapping
        every class name to its softmax probability, and the path of a PNG file showing the
        pre-processed input next to the channel-averaged HLA output.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Convert the input image (from Gradio) to PIL Image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Normalize() expects exactly 3 channels; grayscale, palette and RGBA inputs would fail.
    image = image.convert("RGB")

    # Preprocess the image.
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Run inference.
    with torch.no_grad():
        outputs, hla_out, _, _, _ = model(image_tensor)
        probs = F.softmax(outputs, dim=1)

    pred_label = torch.argmax(probs, dim=1).item()
    pred_label_name = class_mapping[pred_label]
    probabilities = probs.cpu().numpy()[0]
    prob_dict = {class_mapping[i]: float(prob) for i, prob in enumerate(probabilities)}

    # HLA heatmap: average the HLA output over channels -> one value per spatial cell.
    heatmap = hla_out[0].mean(dim=0).cpu().numpy()

    # Undo Normalize() so the pre-processed input can be displayed.
    img = image_tensor[0].permute(1, 2, 0).cpu().numpy()
    img = np.clip(img * np.array(IMAGENET_STD) + np.array(IMAGENET_MEAN), 0, 1)

    # Use the object-oriented Figure API rather than pyplot: it keeps no global state, so
    # concurrent requests cannot interfere and there is no figure to close.
    fig = Figure(figsize=(8, 4))
    axs = fig.subplots(1, 2)
    axs[0].imshow(img)
    axs[0].set_title(f"Input Image\nPredicted: {pred_label_name}")
    axs[0].axis("off")
    axs[1].imshow(heatmap, cmap="jet")
    axs[1].set_title("HLA Heatmap")
    axs[1].axis("off")
    fig.tight_layout()

    # A unique file per request, so simultaneous requests never overwrite each other's output.
    with tempfile.NamedTemporaryFile(
        prefix="visualization_", suffix=".png", delete=False
    ) as tmp:
        fig.savefig(tmp, format="png")
        visualization_path = tmp.name

    return pred_label_name, prob_dict, visualization_path


# Gradio Interface
EXAMPLE_IMAGES = ["examples/surprise.jpg", "examples/sadness.jpg"]
# Only hand Gradio example files that actually exist; None disables the examples panel.
available_examples = [[path] for path in EXAMPLE_IMAGES if os.path.exists(path)] or None

iface = gr.Interface(
    fn=predict_emotion,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=[
        gr.Textbox(label="Predicted Emotion"),
        gr.Label(label="Probabilities"),
        gr.Image(label="Input Image and HLA Heatmap"),
    ],
    title="Facial Emotion Recognition with TripleStreamHLAViT",
    description="Upload an image to predict the facial emotion (Surprise, Fear, Disgust, Happiness, Sadness, Anger, Neutral). This model achieves 82.93% test accuracy on the RAF-DB dataset. The HLA heatmap shows where the model focuses.",
    examples=available_examples,
)

# Launch the interface
if __name__ == "__main__":
    iface.launch(share=False)