from config_loconet import LoCoNetConfig
from transformers import PreTrainedModel

from loconet_encoder import locoencoder
from loss_multi import lossAV, lossA, lossV


class loconet(PreTrainedModel):
    """LoCoNet active-speaker model wrapped as a Hugging Face PreTrainedModel."""

    config_class = LoCoNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = locoencoder(config)
        self.lossAV = lossAV()
        self.lossA = lossA()
        self.lossV = lossV()

    def forward(self, audioFeature, visualFeature, masks, labels=None):
        # visualFeature: (batch, speakers, frames, ...). Flatten the batch and
        # speaker dimensions so each speaker track is processed independently.
        b, s, t = visualFeature.shape[:3]
        visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
        masks = masks.view(b * s, *masks.shape[2:])
        # Only reshape labels when they are provided; the original code reshaped
        # unconditionally, which crashes at inference time when labels is None.
        if labels is not None:
            labels = labels.view(b * s, *labels.shape[2:])

        audioEmbed = self.model.forward_audio_frontend(audioFeature)  # B, C, T, 4
        visualEmbed = self.model.forward_visual_frontend(visualFeature)
        # Tile the audio embedding so every speaker track attends to the same audio.
        audioEmbed = audioEmbed.repeat(s, 1, 1)

        audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
        outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed, b, s)
        outsA = self.model.forward_audio_backend(audioEmbed)
        outsV = self.model.forward_visual_backend(visualEmbed)

        num_frames = masks.sum()  # total number of valid frames (currently unused)

        if labels is not None:
            labels = labels.reshape(-1)
            masks = masks.reshape(-1)
            # The audio-visual loss is the main objective; the unimodal audio
            # and visual losses act as 0.4-weighted auxiliary terms.
            nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels, masks)
            nlossA = self.lossA.forward(outsA, labels, masks)
            nlossV = self.lossV.forward(outsV, labels, masks)
            nloss = nlossAV + 0.4 * nlossA + 0.4 * nlossV
            return {"loss": nloss, "logits": outsAV}
        return {"logits": outsAV}
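

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions, not taken from the source): the tensor
# shapes below (112x112 face crops, 13-dim MFCC features at 4x the video
# frame rate) and the no-argument LoCoNetConfig() call are hypothetical and
# only illustrate how forward() expects its inputs to be laid out.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    config = LoCoNetConfig()  # assumed: the config's defaults are usable as-is
    model = loconet(config)

    b, s, t = 2, 3, 25  # clips, candidate speakers, video frames
    visualFeature = torch.randn(b, s, t, 112, 112)  # assumed face-crop size
    audioFeature = torch.randn(b, t * 4, 13)        # assumed MFCC layout
    masks = torch.ones(b, s, t)
    labels = torch.randint(0, 2, (b, s, t))

    # With labels, forward() returns both the weighted loss and the AV logits;
    # without labels it returns only {"logits": ...}.
    out = model(audioFeature, visualFeature, masks, labels)
    print(out["loss"], out["logits"].shape)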