|
from config_loconet import LoCoNetConfig |
|
from transformers import PreTrainedModel |
|
from loconet_encoder import locoencoder |
|
from loss_multi import lossAV, lossA, lossV |
|
|
|
|
|
class loconet(PreTrainedModel):
    """HuggingFace-compatible wrapper for the LoCoNet active-speaker-detection model.

    Wraps the LoCoNet encoder (audio/visual frontends, cross-attention, and
    AV/A/V backends) and, during training, combines three losses: the main
    audio-visual loss plus auxiliary audio-only and visual-only losses.
    """

    config_class = LoCoNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = locoencoder(config)
        # Bug fix: forward() referenced self.lossAV / self.lossA / self.lossV,
        # but they were never created, so any call with labels raised
        # AttributeError. Instantiate the imported loss modules here.
        self.lossAV = lossAV()
        self.lossA = lossA()
        self.lossV = lossV()

    def forward(self, audioFeature, visualFeature, masks, labels=None):
        """Run the LoCoNet forward pass.

        Args:
            audioFeature: batched audio features for the clip.
                # assumes shape (b, ...) compatible with forward_audio_frontend — TODO confirm
            visualFeature: visual features with leading dims (b, s, t, ...),
                where b = batch size, s = speakers per clip, t = frames.
            masks: per-frame validity masks with leading dims (b, s, ...).
            labels: optional per-frame speaking labels with leading dims
                (b, s, ...); when None the model runs inference-only.

        Returns:
            dict with "logits" (AV backend outputs) and, when labels are
            provided, "loss" (weighted sum of AV, A, and V losses).
        """
        b, s, t = visualFeature.shape[:3]
        # Fold the speaker dimension into the batch so the frontends see
        # (b * s, ...) inputs.
        visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
        masks = masks.view(b * s, *masks.shape[2:])
        # Bug fix: labels.view() was called unconditionally, crashing during
        # inference when labels is None. Reshape only when labels exist.
        if labels is not None:
            labels = labels.view(b * s, *labels.shape[2:])

        audioEmbed = self.model.forward_audio_frontend(audioFeature)
        visualEmbed = self.model.forward_visual_frontend(visualFeature)
        # Replicate the audio embedding once per speaker track so it aligns
        # with the (b * s)-folded visual embeddings.
        audioEmbed = audioEmbed.repeat(s, 1, 1)

        audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
        outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed, b, s)
        outsA = self.model.forward_audio_backend(audioEmbed)
        outsV = self.model.forward_visual_backend(visualEmbed)

        if labels is not None:
            # Flatten to per-frame vectors for the loss modules.
            labels = labels.reshape((-1))
            masks = masks.reshape((-1))
            nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels, masks)
            nlossA = self.lossA.forward(outsA, labels, masks)
            nlossV = self.lossV.forward(outsV, labels, masks)
            # AV loss dominates; A-only and V-only act as 0.4-weighted
            # auxiliary regularizers (standard LoCoNet/TalkNet weighting).
            nloss = nlossAV + 0.4 * nlossA + 0.4 * nlossV
            return {"loss": nloss, "logits": outsAV}
        return {"logits": outsAV}
|
|