"""
File: model.py
Author: Elena Ryumina and Dmitry Ryumin
Description: This module downloads and prepares the pre-trained deep learning
models used for expression recognition: a static ResNet-50 model, an LSTM
dynamic model, and a wav2vec2-based audio model.
License: MIT License
"""

import torch
import requests

# Importing necessary components for the Gradio app
from app.config import config_data
from app.model_architectures import ResNet50, LSTMPyTorch, ExprModelV3
from transformers import AutoFeatureExtractor

device = "cuda" if torch.cuda.is_available() else "cpu"


def load_model(model_url, model_path):
    """Download the file at ``model_url`` and save it to ``model_path``.

    Args:
        model_url: URL fetched with a streaming HTTP GET request.
        model_path: Destination file path; overwritten if it already exists.

    Returns:
        ``model_path`` on success, or ``None`` if the request failed (including
        non-2xx HTTP responses) or the file could not be written. The error is
        printed rather than raised.
    """
    try:
        with requests.get(model_url, stream=True) as response:
            # Without this check an HTTP error page (e.g. a 404 HTML body)
            # would be saved as the "model" and only fail later in torch.load.
            # Checking before open() also leaves any existing file untouched.
            response.raise_for_status()
            with open(model_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        return model_path
    except (requests.RequestException, OSError) as e:
        print(f"Error loading model: {e}")
        return None


def _download_or_raise(model_url, model_path):
    """Download a checkpoint via ``load_model``; raise if it could not be fetched."""
    path = load_model(model_url, model_path)
    if path is None:
        raise RuntimeError(f"Could not download model weights from {model_url!r}")
    return path


def _prepare_model(model, state_dict, warmup_input):
    """Load ``state_dict`` into ``model``, move it to ``device``, switch it to
    eval mode and run one warm-up forward pass on ``warmup_input``.

    Returns ``model``.
    """
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    # The warm-up output is discarded, so there is no need to build an
    # autograd graph for it.
    with torch.no_grad():
        model(warmup_input.to(device))
    return model


# Populated by the backward hooks created with get_gradients, keyed by name.
gradients = {}


def get_gradients(name):
    """Return a full backward hook that stores gradients under ``name``.

    PyTorch invokes full backward hooks as ``hook(module, grad_input,
    grad_output)``; the ``grad_output`` tuple is stored in ``gradients[name]``.
    """
    def hook(module, grad_input, grad_output):
        gradients[name] = grad_output
    return hook


# Populated by the forward hooks created with get_activations, keyed by name.
activations = {}


def get_activations(name):
    """Return a forward hook that stores the module's output, detached from
    the autograd graph, in ``activations[name]``."""
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook


# Random dummy inputs used only for the warm-up passes below.
# NOTE(review): shapes presumably match each model's expected input
# (image batch, 10-step feature sequence, raw audio samples) — confirm.
test_static = torch.rand(1, 3, 224, 224)
test_dynamic = torch.rand(1, 10, 512)
test_audio = torch.rand(1, 64000)

# NOTE(review): torch.load unpickles the downloaded files, which can execute
# arbitrary code — only point the configured URLs at trusted sources.
# map_location=device lets checkpoints saved on a GPU load on CPU-only hosts.
path_static = _download_or_raise(
    config_data.model_static_url, config_data.model_static_path
)
pth_model_static = _prepare_model(
    ResNet50(7, channels=3),
    torch.load(path_static, map_location=device),
    test_static,
)
# Hooks are registered after the warm-up pass, so that pass does not write
# into `gradients` / `activations`.
pth_model_static.layer4.register_full_backward_hook(get_gradients('layer4'))
pth_model_static.layer4.register_forward_hook(get_activations('layer4'))
pth_model_static.fc1.register_forward_hook(get_activations('features'))

path_dynamic = _download_or_raise(
    config_data.model_dynamic_url, config_data.model_dynamic_path
)
pth_model_dynamic = _prepare_model(
    LSTMPyTorch(),
    torch.load(path_dynamic, map_location=device),
    test_dynamic,
)

# Hugging Face model id used for both the feature extractor and the base
# architecture; the fine-tuned weights then come from the downloaded checkpoint.
path_audio_model_1 = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
path_audio_model_2 = _download_or_raise(
    config_data.model_audio_url, config_data.model_audio_path
)

audio_processor = AutoFeatureExtractor.from_pretrained(path_audio_model_1)
audio_model = _prepare_model(
    ExprModelV3.from_pretrained(path_audio_model_1),
    torch.load(path_audio_model_2, map_location=device)["model_state_dict"],
    test_audio,
)