import torch
import torch.nn.functional as F
from torch import nn
import yaml
import argparse
from modules.BEATs.BEATs import BEATs, BEATsConfig
from modules.AudioToken.embedder import FGAEmbedder
from modules.CLIPSeg.clipseg_for_audio import CLIPSeg
from modules.mask_utils import ImageMasker, FeatureMasker
from transformers import AutoTokenizer
class ACL(nn.Module):
def __init__(self, conf_file: str, device: str):
"""
Audio-Grounded Contrastive Learning (ACL) model.
Args:
conf_file (str): Path to the configuration file.
device (str): Device to move the model to.
"""
super(ACL, self).__init__()
# Get configuration
with open(conf_file) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
self.args = argparse.Namespace()
self.args.model = argparse.Namespace(**config['model'])
self.args.clip_embedding_dim = config['clip_conf'][self.args.model.clip]['embedding_dim']
self.args.clip_name = config['clip_conf'][self.args.model.clip]['name']
self.pretrain = argparse.Namespace(**config['pretrain'])
self.args.audio_proj = argparse.Namespace(**config['fga_conf'][self.args.model.audio_proj])
# Init audio encoder
        checkpoint = torch.load(self.pretrain.audio_backbone, map_location='cpu')  # load on CPU; modules are moved to `device` below
cfg = BEATsConfig(checkpoint['cfg'])
self.audio_backbone = BEATs(cfg)
# Text Tokenizer for placeholder prompt
self.tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
# Init audio projection layer
self.audio_proj = FGAEmbedder(input_size=self.args.audio_proj.input_size * 3,
output_size=self.args.audio_proj.output_size)
# Init audio-visual grounder (Grounder: CLIPSeg)
self.av_grounder = CLIPSeg.from_pretrained("CIDAS/clipseg-rd64-refined")
# Init maskers
self.masker_i = ImageMasker(10.0, 14.0, 1.0)
self.masker_f = FeatureMasker(0.5, 0.07)
# Load weights
self.audio_backbone.load_state_dict(checkpoint['model'])
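        # Drop BEATs' prediction head; only the encoder features are used downstream.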
self.audio_backbone.predictor = None
if self.pretrain.audio_proj is not None:
            self.audio_proj.load_state_dict(torch.load(self.pretrain.audio_embedder, map_location='cpu'))
# Set device
self.device = device
self.audio_backbone.to(device=self.device)
self.av_grounder.to(device=self.device)
self.audio_proj.to(device=self.device)
self.masker_i.to(self.device)
self.masker_f.to(self.device)
def get_placeholder_token(self, prompt_text: str):
"""
Get placeholder token from prompt text
Args:
prompt_text (str): prompt text without '{}'
        Returns:
            torch.Tensor: Prompt input_ids zero-padded to the CLIP context length (77 tokens) and moved to the model device.
"""
placeholder_token = self.tokenizer(prompt_text, return_tensors="pt").data['input_ids']
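        # Zero-pad the input_ids up to CLIP's fixed context length of 77 tokens.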
placeholder_token = F.pad(placeholder_token, (0, 77 - placeholder_token.shape[-1])).to(self.device)
return placeholder_token
    def train(self, mode: bool = True):
        """
        Set the module in training mode; the grounder and audio backbone stay frozen regardless.
        Args:
            mode (bool): If True, set the module in training mode; otherwise in evaluation mode.
        """
        super().train(mode)
        self.av_grounder.requires_grad_(False)
        self.audio_backbone.requires_grad_(False)
def encode_audio(self, audio: torch.Tensor, placeholder_token: torch.Tensor, pos: int,
prompt_size: int) -> torch.Tensor:
"""
Encode audio input into audio-driven embedding (Audio-Driven Embedder)
Args:
audio (torch.Tensor): Input audio tensor.
placeholder_token (torch.Tensor): Placeholder token for CLIP Text encoder.
pos (int): Position of audio token.
prompt_size (int): Size of the placeholder prompt.
Returns:
torch.Tensor: Audio-driven embeddings.
"""
audio_feat = self.audio_backbone.extract_features(audio)[1]
audio_token_emb = self.audio_proj(audio_feat).unsqueeze(1)
audio_driven_embedding = self.av_grounder.encode_audio(placeholder_token, audio_token_emb, pos,
prompt_size + audio_token_emb.shape[1])
return audio_driven_embedding
def encode_vision(self, image: torch.Tensor) -> torch.Tensor:
"""
Encode visual input and generate visual embeddings.
Args:
image (torch.Tensor): Input image tensor.
Returns:
torch.Tensor: Visual embeddings.
"""
vision_outputs = self.av_grounder.clip.vision_model(pixel_values=image,
output_attentions=None,
output_hidden_states=True,
return_dict=True)
pooled_output = self.av_grounder.clip.visual_projection(vision_outputs[1])
return pooled_output
def forward_decoder(self, image: torch.Tensor, embedding: torch.Tensor, resolution: int = 224) -> torch.Tensor:
"""
Forward pass of audio-visual grounder
Args:
image (torch.Tensor): Input image tensor.
embedding (torch.Tensor): Condition embedding tensor for grounder.
resolution (int): Resolution of the output.
Returns:
torch.Tensor: Logits from the decoder.
"""
# step 1: forward the query images through the frozen CLIP vision encoder
vision_outputs = self.av_grounder.clip.vision_model(pixel_values=image,
output_attentions=None,
output_hidden_states=True,
return_dict=True)
hidden_states = vision_outputs.hidden_states
# we add +1 here as the hidden states also include the initial embeddings
activations = [hidden_states[i + 1] for i in self.av_grounder.extract_layers]
        # step 2: the conditional embedding is supplied directly via `embedding`
        # (the audio-driven embedding), so no text or image conditioning is computed here
# step 3: forward both the pooled output and the activations through the lightweight decoder to predict masks
decoder_outputs = self.av_grounder.decoder(
activations,
embedding,
output_attentions=None,
output_hidden_states=None,
return_dict=True,
)
logits = decoder_outputs.logits
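        # The decoder returns 2-D logits for a single sample and 3-D logits for a batch;
        # add the missing batch/channel dims so F.interpolate receives a 4-D tensor.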
if logits.ndim == 2:
logits = logits.unsqueeze(0).unsqueeze(1)
else:
logits = logits.unsqueeze(1)
B, c, h, w = image.shape
if (h, w) != (resolution, resolution):
logits = F.interpolate(logits, resolution, mode='bicubic')
return logits
def forward_module(self, image: torch.Tensor, embedding: torch.Tensor, resolution: int = 224,
force_comb: bool = False) -> torch.Tensor:
"""
Forward pass through the module.
Args:
image (torch.Tensor): Input image tensor.
embedding (torch.Tensor): Condition embedding tensor for grounder.
resolution (int): Resolution of the output tensor.
force_comb (bool): If True, force to get logits with all combination audio and image.
Returns:
torch.Tensor: Logits from the decoder.
"""
# N image, 1 embedding case -> [B_i, h, w]
if embedding.shape[0] != image.shape[0] and embedding.shape[0] == 1:
embeddings = embedding.repeat(image.shape[0], 1)
logits = self.forward_decoder(image, embeddings, resolution)
# N image, M embedding case -> [B_i, B_e, h, w]
        elif (embedding.shape[0] != image.shape[0] and embedding.shape[0] != 1 and image.shape[0] != 1) or force_comb:
logit_list = []
for i in range(embedding.shape[0]):
embeddings = embedding[i].unsqueeze(0).repeat(image.shape[0], 1)
logit_list.append(self.forward_decoder(image, embeddings, resolution))
logits = torch.cat(logit_list, dim=1)
# N image, N embedding or 1 image, N embedding -> [B_e, h, w]
else:
logits = self.forward_decoder(image, embedding, resolution)
return logits
def encode_masked_vision(self, image: torch.Tensor, embedding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, float, float]:
"""
Encode masked visual feature both image-level and feature-level.
Args:
image (torch.Tensor): Input image tensor.
embedding (torch.Tensor): Condition embedding tensor for grounder.
Returns:
tuple[torch.Tensor, torch.Tensor, float, float]: Feature masked embeddings, masked image embeddings, positive area, negative area.
"""
B, c, h, w = image.shape
maskclip_feat = self.av_grounder.get_pixels(image) # v^D: [B, c, h, w]
clipseg_mask = self.forward_module(image, embedding, h, force_comb=True) # M^G: [B, B, H, W]
# Area
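        # Mean foreground ratio per (image, audio) pair; diagonal entries correspond to matched
        # (positive) pairs, the remaining entries approximate mismatched (negative) pairs.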
area_matrix = self.masker_i(clipseg_mask).mean((2, 3))
positive_area = area_matrix.diagonal().mean()
negative_area = area_matrix.mean() - positive_area / B
# Feature level masker
feature_mask = F.interpolate(self.masker_f(clipseg_mask), maskclip_feat.shape[2])
# Image level masker
ind = torch.arange(B).to(image.device)
image_mask = self.masker_i(clipseg_mask[ind, ind].unsqueeze(1)) # Positive pair only
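        # Mask-weighted pooling of the dense CLIP features under the soft feature-level mask.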
feature_masked_emb = torch.einsum('bchw,bnhw->bnc', maskclip_feat, feature_mask) / (feature_mask.sum() + 1e-6)
# step 1: forward the query images through the frozen CLIP vision encoder
masked_vision_outputs = self.av_grounder.clip.vision_model(pixel_values=image * image_mask,
output_attentions=None,
output_hidden_states=True,
return_dict=True)
masked_image_emb = self.av_grounder.clip.visual_projection(masked_vision_outputs[1])
return feature_masked_emb, masked_image_emb, positive_area, negative_area
def forward(self, image: torch.Tensor, embedding: torch.Tensor, resolution: int = 224) -> dict:
"""
Forward pass of ACL model.
Args:
image (torch.Tensor): Input image tensor.
embedding (torch.Tensor): Condition embedding tensor for grounder.
resolution (int): Resolution of the output tensor.
Returns:
dict: Output dictionary containing relevant tensors.
"""
if self.training:
# seg_logit = self.forward_module(image, embedding, resolution)
v_f, v_i, p_area, n_area = self.encode_masked_vision(image, embedding)
out_dict = {'v_f': v_f, 'v_i': v_i, 'p_area': p_area, 'n_area': n_area}
else:
seg_logit = self.forward_module(image, embedding, resolution)
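            # Inference: turn the grounder logits into a localization heatmap via the image-level masker.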
heatmap = self.masker_i(seg_logit, infer=True)
out_dict = {'heatmap': heatmap}
return out_dict
def save(self, model_dir: str):
"""
Save model parameters to a file. (Only trainable parts)
Args:
            model_dir (str): Path of the checkpoint file to save.
"""
ckp = {'audio_proj': self.audio_proj.state_dict(), 'masker_i': self.masker_i.state_dict()}
torch.save(ckp, model_dir)
def load(self, model_dir: str):
"""
Load model parameters from a file. (Only trainable parts)
Args:
            model_dir (str): Path of the checkpoint file to load.
"""
ckp = torch.load(model_dir, map_location=self.device)
self.audio_proj.load_state_dict(ckp['audio_proj'])
self.masker_i.load_state_dict(ckp['masker_i'])
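

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The config path,
# tensor shapes, prompt text, and the `pos`/`prompt_size` values below are
# illustrative assumptions; adjust them to your configuration and data pipeline.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ACL("config/acl.yaml", device)  # hypothetical config path
    model.eval()

    # Placeholder prompt; the audio token is injected at position `pos`.
    placeholder_token = model.get_placeholder_token("A photo of a")

    with torch.no_grad():
        audio = torch.randn(1, 16000 * 10, device=device)   # assumed: 10 s of 16 kHz waveform
        image = torch.randn(1, 3, 352, 352, device=device)  # assumed: CLIPSeg-preprocessed image
        embedding = model.encode_audio(audio, placeholder_token, pos=5, prompt_size=6)
        out = model(image, embedding, resolution=352)
        print(out["heatmap"].shape)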