LoCoNet_ASD / modeling_loconet.py
Superxixixi's picture
Upload loconet
6f6705b
raw
history blame contribute delete
No virus
1.76 kB
from config_loconet import LoCoNetConfig
from transformers import PreTrainedModel
from loconet_encoder import locoencoder
from loss_multi import lossAV, lossA, lossV
class loconet(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the LoCoNet ``locoencoder``.

    Holds the encoder plus three loss objects (audio-visual, audio-only,
    visual-only) so that ``forward`` can also return a training loss when
    ``labels`` are supplied.
    """

    config_class = LoCoNetConfig

    # Weight applied to each auxiliary (audio-only and visual-only) loss term.
    AUX_LOSS_WEIGHT = 0.4

    def __init__(self, config):
        super().__init__(config)
        self.model = locoencoder(config)
        self.lossAV = lossAV()
        self.lossA = lossA()
        self.lossV = lossV()

    def forward(self, audioFeature, visualFeature, masks, labels=None):
        """Run the encoder and, when ``labels`` is given, compute the loss.

        Args:
            audioFeature: audio input whose leading dim is the batch size
                ``b`` (it is expanded to ``b * s`` rows below).
            visualFeature: visual input with leading dims ``(b, s, ...)``;
                ``s`` is the number of speaker tracks per sample.
            masks: per-frame mask with leading dims ``(b, s)``; only used
                when ``labels`` is given.
            labels: optional per-frame targets with leading dims ``(b, s)``.

        Returns:
            ``{"loss": ..., "logits": outsAV}`` when ``labels`` is not None,
            otherwise ``{"logits": outsAV}``.
        """
        b, s = visualFeature.shape[:2]
        # Flatten (b, s, ...) -> (b * s, ...) batch-major: row i * s + j is
        # speaker track j of sample i.
        visualFeature = visualFeature.reshape(b * s, *visualFeature.shape[2:])

        audioEmbed = self.model.forward_audio_frontend(audioFeature)
        visualEmbed = self.model.forward_visual_frontend(visualFeature)
        # Give every speaker track a copy of its own sample's audio, in the
        # same batch-major order as the visual rows: repeat_interleave yields
        # [a0]*s + [a1]*s + ..., whereas repeat(s, 1, 1) would tile
        # [a0, a1, ..., a0, a1, ...] and mis-pair audio/video when b > 1.
        audioEmbed = audioEmbed.repeat_interleave(s, dim=0)
        audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
        # NOTE(review): the backend receives b and s, presumably to regroup
        # the b * s rows per sample (batch-major) — confirm in locoencoder.
        outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed, b, s)

        if labels is None:
            return {"logits": outsAV}

        outsA = self.model.forward_audio_backend(audioEmbed)
        outsV = self.model.forward_visual_backend(visualEmbed)

        # Flattening the full (b, s, ...) tensors gives the same row order as
        # the batch-major (b * s, ...) layout used for the features above.
        flat_labels = labels.reshape(-1)
        flat_masks = masks.reshape(-1)
        # lossAV returns a 4-tuple; only its first element (the loss) is used.
        nlossAV, _, _, _ = self.lossAV.forward(outsAV, flat_labels, flat_masks)
        nlossA = self.lossA.forward(outsA, flat_labels, flat_masks)
        nlossV = self.lossV.forward(outsV, flat_labels, flat_masks)
        nloss = nlossAV + self.AUX_LOSS_WEIGHT * nlossA + self.AUX_LOSS_WEIGHT * nlossV
        return {"loss": nloss, "logits": outsAV}