import torch
import torch.nn.functional as F
from torch import nn
import yaml
import argparse
from modules.BEATs.BEATs import BEATs, BEATsConfig
from modules.AudioToken.embedder import FGAEmbedder
from modules.CLIPSeg.clipseg_for_audio import CLIPSeg
from modules.mask_utils import ImageMasker, FeatureMasker
from transformers import AutoTokenizer


class ACL(nn.Module):
    def __init__(self, conf_file: str, device: str):
        """
        Audio-Grounded Contrastive Learning (ACL) model.
        Args:
            conf_file (str): Path to the configuration file.
            device (str): Device to move the model to.
        """
        super(ACL, self).__init__()
        # Get configuration
        with open(conf_file) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        self.args = argparse.Namespace()
        self.args.model = argparse.Namespace(**config['model'])
        self.args.clip_embedding_dim = config['clip_conf'][self.args.model.clip]['embedding_dim']
        self.args.clip_name = config['clip_conf'][self.args.model.clip]['name']
        self.pretrain = argparse.Namespace(**config['pretrain'])
        self.args.audio_proj = argparse.Namespace(**config['fga_conf'][self.args.model.audio_proj])
        # Init audio encoder
        checkpoint = torch.load(self.pretrain.audio_backbone)
        cfg = BEATsConfig(checkpoint['cfg'])
        self.audio_backbone = BEATs(cfg)
        # Text tokenizer for placeholder prompt
        self.tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
        # Init audio projection layer
        self.audio_proj = FGAEmbedder(input_size=self.args.audio_proj.input_size * 3,
                                      output_size=self.args.audio_proj.output_size)
        # Init audio-visual grounder (Grounder: CLIPSeg)
        self.av_grounder = CLIPSeg.from_pretrained("CIDAS/clipseg-rd64-refined")
        # Init maskers
        self.masker_i = ImageMasker(10.0, 14.0, 1.0)
        self.masker_f = FeatureMasker(0.5, 0.07)
        # Load weights
        self.audio_backbone.load_state_dict(checkpoint['model'])
        self.audio_backbone.predictor = None
        if self.pretrain.audio_proj is not None:
            self.audio_proj.load_state_dict(torch.load(self.pretrain.audio_embedder))
        # Set device
        self.device = device
        self.audio_backbone.to(device=self.device)
        self.av_grounder.to(device=self.device)
        self.audio_proj.to(device=self.device)
        self.masker_i.to(self.device)
        self.masker_f.to(self.device)

    def get_placeholder_token(self, prompt_text: str):
        """
        Get placeholder token from prompt text.
        Args:
            prompt_text (str): Prompt text without '{}'.
        Returns:
            CLIPTokenizerFast result with prompt text.
        """
        placeholder_token = self.tokenizer(prompt_text, return_tensors="pt").data['input_ids']
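        # Pad the token ids to CLIP's fixed text context length of 77 tokens.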
        placeholder_token = F.pad(placeholder_token, (0, 77 - placeholder_token.shape[-1])).to(self.device)
        return placeholder_token

    def train(self, mode: bool = True):
        """
        Set the module in training mode while keeping the grounder and audio backbone frozen.
        Args:
            mode (bool): If True, set the module in training mode.
        """
        super().train(mode)
        self.av_grounder.requires_grad_(False)
        self.audio_backbone.requires_grad_(False)

    def encode_audio(self, audio: torch.Tensor, placeholder_token: torch.Tensor, pos: int,
                     prompt_size: int) -> torch.Tensor:
        """
        Encode audio input into audio-driven embedding (Audio-Driven Embedder).
        Args:
            audio (torch.Tensor): Input audio tensor.
            placeholder_token (torch.Tensor): Placeholder token for the CLIP text encoder.
            pos (int): Position of the audio token.
            prompt_size (int): Size of the placeholder prompt.
        Returns:
            torch.Tensor: Audio-driven embeddings.
        """
        audio_feat = self.audio_backbone.extract_features(audio)[1]
        audio_token_emb = self.audio_proj(audio_feat).unsqueeze(1)
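        # The grounder's encode_audio injects the projected audio token into the
        # placeholder prompt at position `pos`, yielding the audio-driven embedding.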
        audio_driven_embedding = self.av_grounder.encode_audio(placeholder_token, audio_token_emb, pos,
                                                               prompt_size + audio_token_emb.shape[1])
        return audio_driven_embedding

    def encode_vision(self, image: torch.Tensor) -> torch.Tensor:
        """
        Encode visual input and generate visual embeddings.
        Args:
            image (torch.Tensor): Input image tensor.
        Returns:
            torch.Tensor: Visual embeddings.
        """
        vision_outputs = self.av_grounder.clip.vision_model(pixel_values=image,
                                                            output_attentions=None,
                                                            output_hidden_states=True,
                                                            return_dict=True)
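        # vision_outputs[1] is the pooled (CLS) output of the CLIP vision encoder;
        # visual_projection maps it into the joint image-text embedding space.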
        pooled_output = self.av_grounder.clip.visual_projection(vision_outputs[1])
        return pooled_output

    def forward_decoder(self, image: torch.Tensor, embedding: torch.Tensor, resolution: int = 224) -> torch.Tensor:
        """
        Forward pass of the audio-visual grounder's decoder.
        Args:
            image (torch.Tensor): Input image tensor.
            embedding (torch.Tensor): Condition embedding tensor for the grounder.
            resolution (int): Resolution of the output.
        Returns:
            torch.Tensor: Logits from the decoder.
        """
        # step 1: forward the query images through the frozen CLIP vision encoder
        vision_outputs = self.av_grounder.clip.vision_model(pixel_values=image,
                                                            output_attentions=None,
                                                            output_hidden_states=True,
                                                            return_dict=True)
        hidden_states = vision_outputs.hidden_states
        # we add +1 here as the hidden states also include the initial embeddings
        activations = [hidden_states[i + 1] for i in self.av_grounder.extract_layers]
        # step 2: the conditional (audio-driven) embedding is provided as an input argument
        # step 3: forward both the conditional embedding and the activations through the lightweight decoder to predict masks
        decoder_outputs = self.av_grounder.decoder(
            activations,
            embedding,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=True,
        )
        logits = decoder_outputs.logits
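        # The decoder returns [H, W] for a single sample or [B, H, W] for a batch;
        # normalize both cases to [B, 1, H, W] before optional resizing.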
        if logits.ndim == 2:
            logits = logits.unsqueeze(0).unsqueeze(1)
        else:
            logits = logits.unsqueeze(1)
        B, c, h, w = image.shape
        if (h, w) != (resolution, resolution):
            logits = F.interpolate(logits, resolution, mode='bicubic')
        return logits

    def forward_module(self, image: torch.Tensor, embedding: torch.Tensor, resolution: int = 224,
                       force_comb: bool = False) -> torch.Tensor:
        """
        Forward pass through the module.
        Args:
            image (torch.Tensor): Input image tensor.
            embedding (torch.Tensor): Condition embedding tensor for the grounder.
            resolution (int): Resolution of the output tensor.
            force_comb (bool): If True, compute logits for every combination of audio embedding and image.
        Returns:
            torch.Tensor: Logits from the decoder.
        """
        # N image, 1 embedding case -> [B_i, h, w]
        if embedding.shape[0] != image.shape[0] and embedding.shape[0] == 1:
            embeddings = embedding.repeat(image.shape[0], 1)
            logits = self.forward_decoder(image, embeddings, resolution)
        # N image, M embedding case -> [B_i, B_e, h, w]
        elif embedding.shape[0] != image.shape[0] and embedding.shape[0] != 1 and image.shape[0] != 1 or force_comb:
            logit_list = []
            for i in range(embedding.shape[0]):
                embeddings = embedding[i].unsqueeze(0).repeat(image.shape[0], 1)
                logit_list.append(self.forward_decoder(image, embeddings, resolution))
            logits = torch.cat(logit_list, dim=1)
        # N image, N embedding or 1 image, N embedding -> [B_e, h, w]
        else:
            logits = self.forward_decoder(image, embedding, resolution)
        return logits

    def encode_masked_vision(self, image: torch.Tensor, embedding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, float, float]:
        """
        Encode masked visual features at both the image level and the feature level.
        Args:
            image (torch.Tensor): Input image tensor.
            embedding (torch.Tensor): Condition embedding tensor for the grounder.
        Returns:
            tuple[torch.Tensor, torch.Tensor, float, float]: Feature-masked embeddings, masked image embeddings, positive area, negative area.
        """
        B, c, h, w = image.shape
        maskclip_feat = self.av_grounder.get_pixels(image)  # v^D: [B, c, h, w]
        clipseg_mask = self.forward_module(image, embedding, h, force_comb=True)  # M^G: [B, B, H, W]
        # Activated-area statistics: diagonal entries correspond to matched audio-image pairs
        area_matrix = self.masker_i(clipseg_mask).mean((2, 3))
        positive_area = area_matrix.diagonal().mean()
        negative_area = area_matrix.mean() - positive_area / B
        # Feature level masker
        feature_mask = F.interpolate(self.masker_f(clipseg_mask), maskclip_feat.shape[2])
        # Image level masker
        ind = torch.arange(B).to(image.device)
        image_mask = self.masker_i(clipseg_mask[ind, ind].unsqueeze(1))  # Positive pair only
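        # Masked spatial pooling: aggregate the dense CLIP features under each soft
        # feature-level mask to get one masked feature embedding per audio condition.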
        feature_masked_emb = torch.einsum('bchw,bnhw->bnc', maskclip_feat, feature_mask) / (feature_mask.sum() + 1e-6)
        # Forward the image-level-masked inputs through the frozen CLIP vision encoder
        masked_vision_outputs = self.av_grounder.clip.vision_model(pixel_values=image * image_mask,
                                                                   output_attentions=None,
                                                                   output_hidden_states=True,
                                                                   return_dict=True)
        masked_image_emb = self.av_grounder.clip.visual_projection(masked_vision_outputs[1])
        return feature_masked_emb, masked_image_emb, positive_area, negative_area

    def forward(self, image: torch.Tensor, embedding: torch.Tensor, resolution: int = 224) -> dict:
        """
        Forward pass of the ACL model.
        Args:
            image (torch.Tensor): Input image tensor.
            embedding (torch.Tensor): Condition embedding tensor for the grounder.
            resolution (int): Resolution of the output tensor.
        Returns:
            dict: Output dictionary containing the relevant tensors.
        """
        if self.training:
            # seg_logit = self.forward_module(image, embedding, resolution)
            v_f, v_i, p_area, n_area = self.encode_masked_vision(image, embedding)
            out_dict = {'v_f': v_f, 'v_i': v_i, 'p_area': p_area, 'n_area': n_area}
        else:
            seg_logit = self.forward_module(image, embedding, resolution)
            heatmap = self.masker_i(seg_logit, infer=True)
            out_dict = {'heatmap': heatmap}
        return out_dict

    def save(self, model_dir: str):
        """
        Save model parameters to a file (only the trainable parts).
        Args:
            model_dir (str): Path of the checkpoint file to save to.
        """
        ckp = {'audio_proj': self.audio_proj.state_dict(), 'masker_i': self.masker_i.state_dict()}
        torch.save(ckp, model_dir)

    def load(self, model_dir: str):
        """
        Load model parameters from a file (only the trainable parts).
        Args:
            model_dir (str): Path of the checkpoint file to load from.
        """
        ckp = torch.load(model_dir, map_location=self.device)
        self.audio_proj.load_state_dict(ckp['audio_proj'])
        self.masker_i.load_state_dict(ckp['masker_i'])
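

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# config path, checkpoint layout, prompt, token position, and tensor shapes
# below are assumptions; adapt them to your own configuration and data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Assumed YAML config following the keys read in ACL.__init__
    # ('model', 'clip_conf', 'pretrain', 'fga_conf').
    model = ACL("config/acl_config.yaml", device)
    model.eval()

    # Placeholder prompt; pos/prompt_size depend on the tokenizer convention
    # expected by CLIPSeg.encode_audio and are assumed values here.
    placeholder_token = model.get_placeholder_token("a photo of a")

    # Dummy inputs: one 224x224 RGB image and a 10-second 16 kHz waveform.
    image = torch.randn(1, 3, 224, 224, device=device)
    audio = torch.randn(1, 16000 * 10, device=device)

    with torch.no_grad():
        embedding = model.encode_audio(audio, placeholder_token, pos=5, prompt_size=6)
        out = model(image, embedding, resolution=224)
    print(out['heatmap'].shape)  # audio-grounded heatmap over the input image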