# AVCER / app/model.py (author: ElenaRyumina, commit 47aeb66, 2.55 kB)
"""
File: model.py
Author: Elena Ryumina and Dmitry Ryumin
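Description: This module downloads and initializes the pre-trained models used for facial expression
and emotion recognition (a static ResNet50 model, a dynamic LSTM model and a wav2vec2-based audio
model) and provides hooks for capturing layer activations and gradients.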
License: MIT License
"""
import torch
import requests
# Importing necessary components for the Gradio app
from app.config import config_data
from app.model_architectures import ResNet50, LSTMPyTorch, ExprModelV3
from transformers import AutoFeatureExtractor
# Target device for every model in this module: the GPU when PyTorch can see
# one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def load_model(model_url, model_path, timeout=60):
    """Download a file from ``model_url`` and save it to ``model_path``.

    Despite the name, nothing is deserialized here. The HTTP response body is
    streamed to disk in 8 KiB chunks, and the path is returned for the caller
    to pass to ``torch.load``.

    The body is first written to ``<model_path>.part`` and is moved onto
    ``model_path`` only after the whole download succeeded. A failed or
    interrupted download therefore never leaves a truncated file at
    ``model_path``, and never overwrites a previously downloaded one.

    Args:
        model_url: URL of the file to download.
        model_path: Destination path; replaced if it already exists.
        timeout: Seconds ``requests`` waits to connect and between received
            bytes. This is not a limit on total download time. ``None``
            waits indefinitely.

    Returns:
        ``model_path`` on success, or ``None`` if anything went wrong
        (the error is printed).
    """
    import contextlib
    import os

    tmp_path = f"{model_path}.part"
    try:
        with requests.get(model_url, stream=True, timeout=timeout) as response:
            # Treat HTTP errors (404, 5xx, ...) as failures. Otherwise the
            # server's error page would be saved as if it were the model.
            response.raise_for_status()
            with open(tmp_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        # Move the complete download into place in one step (atomic on POSIX).
        os.replace(tmp_path, model_path)
        return model_path
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
    finally:
        # After a successful os.replace the temporary file is already gone.
        # After a failure, this removes any partial download.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
# Layer name -> value captured by the hook built by get_gradients().
# Each hook call overwrites the previous entry for that name.
gradients = dict()


def get_gradients(name):
    """Build a hook that stores its third positional argument as ``gradients[name]``.

    In this module the hook is attached with
    ``Module.register_full_backward_hook``, which calls it as
    ``hook(module, grad_input, grad_output)``. The stored value is therefore
    ``grad_output``, kept as-is: it is neither detached nor copied.
    """

    def _save_grad_output(_module, _grad_input, grad_output):
        gradients[name] = grad_output

    return _save_grad_output
# Layer name -> detached output captured by the hook built by get_activations().
# Each hook call overwrites the previous entry for that name.
activations = dict()


def get_activations(name):
    """Build a forward hook that stores ``output.detach()`` as ``activations[name]``.

    Intended for ``Module.register_forward_hook``, which calls the hook as
    ``hook(module, input, output)``. Detaching keeps the stored tensor out of
    the autograd graph.
    """

    def _save_output(_module, _inputs, result):
        activations[name] = result.detach()

    return _save_output
# Random dummy inputs (uniform in [0, 1)), used only for the warm-up forward
# passes below. The tuple is evaluated left to right, so the RNG draw order
# is static, dynamic, audio.
#   test_static:  (1, 3, 224, 224) - presumably one 3-channel 224x224 image, NCHW
#   test_dynamic: (1, 10, 512)     - presumably (batch, time steps, features); TODO confirm
#   test_audio:   (1, 64000)       - presumably raw waveform samples; TODO confirm
test_static, test_dynamic, test_audio = (
    torch.rand((1, 3, 224, 224)),
    torch.rand((1, 10, 512)),
    torch.rand((1, 64000)),
)
# --- Static ResNet50 model ----------------------------------------------------
path_static = load_model(config_data.model_static_url, config_data.model_static_path)
if path_static is None:
    # load_model() prints the error and returns None. Fail here with a clear
    # message instead of letting torch.load(None) raise something obscure.
    raise RuntimeError(
        f"Failed to download the static model from {config_data.model_static_url}"
    )
# First argument 7: presumably the number of output classes - TODO confirm.
pth_model_static = ResNet50(7, channels=3)
# The checkpoint may have been saved from a GPU. Loading it onto the CPU first
# makes this work on CPU-only hosts; the model is moved to `device` afterwards.
# NOTE(review): torch.load unpickles a file fetched from a remote URL; consider
# weights_only=True if the installed torch version supports it.
pth_model_static.load_state_dict(torch.load(path_static, map_location="cpu"))
pth_model_static.to(device)
pth_model_static.eval()
# Warm-up forward pass with the dummy input. It runs before the hooks below
# are attached, so nothing is recorded, and no gradients are needed.
with torch.no_grad():
    pth_model_static(test_static.to(device))
# Record layer4 gradients (backward hook) plus layer4 and fc1 outputs (forward
# hooks) into the module-level `gradients` / `activations` dicts. Presumably
# used for Grad-CAM-style visualisation and feature extraction - TODO confirm.
pth_model_static.layer4.register_full_backward_hook(get_gradients('layer4'))
pth_model_static.layer4.register_forward_hook(get_activations('layer4'))
pth_model_static.fc1.register_forward_hook(get_activations('features'))
# --- Dynamic LSTM model -------------------------------------------------------
path_dynamic = load_model(config_data.model_dynamic_url, config_data.model_dynamic_path)
if path_dynamic is None:
    # load_model() prints the error and returns None. Fail here with a clear
    # message instead of letting torch.load(None) raise something obscure.
    raise RuntimeError(
        f"Failed to download the dynamic model from {config_data.model_dynamic_url}"
    )
pth_model_dynamic = LSTMPyTorch()
# Load onto the CPU first so GPU-saved checkpoints also work on CPU-only hosts.
# NOTE(review): torch.load unpickles a file fetched from a remote URL; consider
# weights_only=True if the installed torch version supports it.
pth_model_dynamic.load_state_dict(torch.load(path_dynamic, map_location="cpu"))
pth_model_dynamic.to(device)
pth_model_dynamic.eval()
# Warm-up forward pass with the dummy input; no gradients are needed.
with torch.no_grad():
    pth_model_dynamic(test_dynamic.to(device))
# --- Audio model (wav2vec2 backbone + fine-tuned weights) ---------------------
# HuggingFace Hub id of the base model. It supplies both the feature extractor
# and the initial weights for ExprModelV3.
path_audio_model_1 = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
# Downloaded checkpoint whose "model_state_dict" entry overrides those weights.
path_audio_model_2 = load_model(config_data.model_audio_url, config_data.model_audio_path)
if path_audio_model_2 is None:
    # load_model() prints the error and returns None. Fail here with a clear
    # message instead of letting torch.load(None) raise something obscure.
    raise RuntimeError(
        f"Failed to download the audio model from {config_data.model_audio_url}"
    )
audio_processor = AutoFeatureExtractor.from_pretrained(path_audio_model_1)
audio_model = ExprModelV3.from_pretrained(path_audio_model_1)
# Load onto the CPU first so GPU-saved checkpoints also work on CPU-only hosts.
# NOTE(review): torch.load unpickles a file fetched from a remote URL; consider
# weights_only=True if the installed torch version supports it.
audio_model.load_state_dict(
    torch.load(path_audio_model_2, map_location="cpu")["model_state_dict"]
)
audio_model.to(device)
audio_model.eval()
# Warm-up forward pass with the dummy waveform; no gradients are needed.
with torch.no_grad():
    audio_model(test_audio.to(device))