import random

import torch
import torch.nn as nn
import torchaudio
from transformers import RobertaTokenizer

from models.CLAP.open_clip import create_model
# NOTE(review): the module path of this import was missing in the reviewed
# source (`from import get_audio_features`, a syntax error). It is restored to
# the `training.data` module of the CLAP package -- confirm against the repo.
from models.CLAP.training.data import get_audio_features


class CLAP_Encoder(nn.Module):
    """Frozen CLAP model used to embed audio or text queries.

    The wrapped model is created once in ``__init__``, switched to eval mode,
    and all of its parameters have ``requires_grad`` disabled, so this module
    only ever produces detached embeddings.
    """

    # Sampling rate the audio is resampled to before feature extraction.
    _CLAP_SAMPLE_RATE = 48000
    # Length limit handed to ``get_audio_features`` (480000 samples, i.e. 10 s
    # at 48 kHz). Presumably the max clip length -- confirm against the helper.
    _CLAP_MAX_SAMPLES = 480000
    # 'hybird' is the historical (misspelled) flag value; it stays accepted so
    # existing configs keep working alongside the correct spelling.
    _HYBRID_MODALITIES = ("hybrid", "hybird")

    def __init__(
        self,
        pretrained_path='checkpoint/',
        sampling_rate=32000,
        amodel="HTSAT-base",
    ):
        """Build the CLAP model and freeze it.

        Args:
            pretrained_path: Passed to ``create_model`` as the pretrained
                checkpoint location.
            sampling_rate: Sampling rate of the waveforms later given to the
                audio branch. Only 32000 is accepted at embedding time.
            amodel: Audio model name passed to ``create_model``
                (original note: or 'PANN-14').
        """
        super().__init__()
        self.device = "cpu"
        self.precision = "fp32"
        self.amodel = amodel  # or 'PANN-14'
        self.tmodel = "roberta"  # the best text encoder in our training
        self.enable_fusion = False  # False if you do not want to use the fusion model
        self.fusion_type = "aff_2d"
        self.pretrained = pretrained_path
        self.sampling_rate = sampling_rate
        self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")

        self.model, self.model_cfg = create_model(
            self.amodel,
            self.tmodel,
            self.pretrained,
            precision=self.precision,
            device=self.device,
            enable_fusion=self.enable_fusion,
            fusion_type=self.fusion_type,
        )

        # The encoder is used for inference only: freeze every weight.
        for p in self.model.parameters():
            p.requires_grad = False

        self.model.eval()
        self.encoder_type = 'CLAP'

    def batch_to_list(self, batch):
        """Split ``batch`` along its first dimension into a list of items."""
        return [batch[i] for i in range(batch.size(0))]

    def _get_audio_embed(self, batch):
        """Return detached CLAP audio embeddings for a batch of waveforms.

        Args:
            batch: Waveform tensor with the batch on dimension 0. The original
                comments mention both ``[B, samples]`` and
                ``[bs, 1, t-samples]`` -- TODO confirm the expected rank.

        Raises:
            AssertionError: If ``self.sampling_rate`` is not 32000.
        """
        with torch.no_grad():
            audio_dict_list = []
            assert (
                self.sampling_rate == 32000
            ), "We only support 32000 sampling rate"

            batch = torchaudio.functional.resample(
                batch, orig_freq=self.sampling_rate, new_freq=self._CLAP_SAMPLE_RATE
            )
            for waveform in self.batch_to_list(batch):
                audio_dict = {}
                # Original note: the 'fusion' truncate mode can be changed to
                # 'rand_trunc' if run in unfusion mode.
                audio_dict = get_audio_features(
                    audio_dict,
                    waveform,
                    self._CLAP_MAX_SAMPLES,
                    data_truncating="fusion",
                    data_filling="repeatpad",
                    audio_cfg=self.model_cfg["audio_cfg"],
                )
                audio_dict_list.append(audio_dict)
            # Original note: result is [bs, 512].
            embed = self.model.get_audio_embedding(audio_dict_list)

            return embed.detach()

    def _get_text_embed(self, batch):
        """Return detached CLAP text embeddings for a batch of captions.

        Args:
            batch: Sequence of texts (presumably a list of ``str``; a
                single-item batch must support ``batch * 2``).
        """
        # ``tokenizer`` squeezes dimension 0 of every tensor, which would drop
        # the batch dimension for a single text. Duplicate the item to keep
        # the batch 2-D, then return only the first embedding.
        double_batch = False
        if len(batch) == 1:
            batch = batch * 2
            double_batch = True
        with torch.no_grad():
            text_data = self.tokenizer(batch)
            embed = self.model.get_text_embedding(text_data)
        if double_batch:
            embed = embed[0].unsqueeze(0)

        return embed.detach()

    def get_query_embed(self, modality, audio=None, text=None, use_text_ratio=0.5, device=None):
        """Return the query embedding for the requested modality as float32.

        Args:
            modality: ``'audio'``, ``'text'``, or ``'hybrid'`` (the legacy
                spelling ``'hybird'`` is also accepted).
            audio: Waveform batch, used by the audio branch.
            text: Text batch, used by the text branch.
            use_text_ratio: In hybrid mode the text branch is taken when
                ``random.random() <= use_text_ratio``, otherwise audio.
            device: Unused; kept for interface compatibility.

        Raises:
            NotImplementedError: If ``modality`` is not a supported value.
        """
        if modality == 'audio':
            embed = self._get_audio_embed(audio)
        elif modality == 'text':
            embed = self._get_text_embed(text)
        elif modality in self._HYBRID_MODALITIES:
            if random.random() > use_text_ratio:
                embed = self._get_audio_embed(audio)
            else:
                embed = self._get_text_embed(text)
        else:
            raise NotImplementedError("Please check flag 'training_modality'.")

        return embed.float()

    def tokenizer(self, text):
        """Tokenize ``text`` with the RoBERTa tokenizer.

        Every sequence is padded/truncated to 512 tokens and returned as
        PyTorch tensors; dimension 0 of each tensor is squeezed, so a single
        text yields 1-D tensors.
        """
        result = self.tokenize(
            text,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        return {k: v.squeeze(0) for k, v in result.items()}