Spaces:
Sleeping
Sleeping
File size: 12,471 Bytes
fa25a07 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 |
import random
import torchaudio
import collections
import re
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer
from wav_evaluation.models.utils import read_config_as_args
from wav_evaluation.models.clap import CLAP
import math
import torchaudio.transforms as T
import os
import torch
# String-like types; CLAPWrapper.default_collate returns a batch of these
# unchanged instead of converting it to a tensor.
string_classes = (str, bytes)
class CLAPWrapper():
    """
    A class for interfacing CLAP model.

    Loads the model and tokenizer described by a config file, turns audio
    files and text queries into batched tensors, and exposes helpers that
    compute L2-normalized embeddings and their similarity.
    """

    def __init__(self, model_fp, config_path, use_cuda=False):
        """Read the config file and load the model and tokenizer.

        Args:
            model_fp: Path of the checkpoint handed to ``torch.load``.
            config_path: Path of the config file; its text is parsed by
                ``read_config_as_args``.
            use_cuda: When true and CUDA is available, the model and all
                preprocessed inputs are moved to the GPU.
        """
        self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
        self.file_path = os.path.realpath(__file__)
        self.default_collate_err_msg_format = (
            "default_collate: batch must contain tensors, numpy arrays, numbers, "
            "dicts or lists; found {}")
        with open(config_path, 'r') as f:
            self.config_as_str = f.read()
        self.model_fp = model_fp
        self.use_cuda = use_cuda
        self.clap, self.tokenizer, self.args = self.load_clap()

    def _to_device(self, tensor):
        """Move ``tensor`` to the GPU when CUDA was requested and is available."""
        if self.use_cuda and torch.cuda.is_available():
            return tensor.cuda()
        return tensor

    def load_clap(self):
        r"""Load CLAP model with args from config file

        Returns:
            ``(clap, tokenizer, args)``: the model in eval mode, the tokenizer
            for ``args.text_model``, and the parsed config.
        """
        args = read_config_as_args(self.config_as_str, is_config_str=True)

        # Which tokenizer outputs get flattened / moved to the device in
        # preprocess_text.
        if 'bert' in args.text_model:
            self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
        else:
            self.token_keys = ['input_ids', 'attention_mask']

        clap = CLAP(
            audioenc_name=args.audioenc_name,
            sample_rate=args.sampling_rate,
            window_size=args.window_size,
            hop_size=args.hop_size,
            mel_bins=args.mel_bins,
            fmin=args.fmin,
            fmax=args.fmax,
            classes_num=args.num_classes,
            out_emb=args.out_emb,
            text_model=args.text_model,
            transformer_embed_dim=args.transformer_embed_dim,
            d_proj=args.d_proj
        )

        # Load pretrained weights for model.
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
        clap.load_state_dict(model_state_dict)
        clap.eval()  # set clap in eval mode
        tokenizer = AutoTokenizer.from_pretrained(args.text_model)

        if self.use_cuda and torch.cuda.is_available():
            clap = clap.cuda()

        return clap, tokenizer, args

    def default_collate(self, batch):
        r"""Puts each data field into a tensor with outer dimension batch size

        Dispatches on the type of ``batch[0]``: tensors are stacked, numpy
        arrays and Python numbers become tensors, strings are returned
        unchanged, and mappings / namedtuples / sequences are collated
        recursively field by field.

        Raises:
            TypeError: For unsupported element types, including numpy arrays
                of strings or objects.
            RuntimeError: When sequence elements differ in length.
        """
        elem = batch[0]
        elem_type = type(elem)
        if isinstance(elem, torch.Tensor):
            out = None
            if torch.utils.data.get_worker_info() is not None:
                # If we're in a background process, concatenate directly into a
                # shared memory tensor to avoid an extra copy
                numel = sum([x.numel() for x in batch])
                storage = elem.storage()._new_shared(numel)
                out = elem.new(storage)
            return torch.stack(batch, 0, out=out)
        elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
                and elem_type.__name__ != 'string_':
            if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
                # array of string classes and object
                if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                    raise TypeError(
                        self.default_collate_err_msg_format.format(elem.dtype))
                return self.default_collate([torch.as_tensor(b) for b in batch])
            elif elem.shape == ():  # scalars
                return torch.as_tensor(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float64)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        elif isinstance(elem, string_classes):
            return batch
        elif isinstance(elem, collections.abc.Mapping):
            return {key: self.default_collate([d[key] for d in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
            return elem_type(*(self.default_collate(samples) for samples in zip(*batch)))
        elif isinstance(elem, collections.abc.Sequence):
            # check to make sure that the elements in batch have consistent size
            it = iter(batch)
            elem_size = len(next(it))
            if not all(len(other) == elem_size for other in it):
                raise RuntimeError(
                    'each element in list of batch should be of equal size')
            transposed = zip(*batch)
            return [self.default_collate(samples) for samples in transposed]

        raise TypeError(self.default_collate_err_msg_format.format(elem_type))

    def resample_and_duration(self, wav_sr, audio_duration, resample=False):
        r"""Flatten a waveform and pad or crop it to ``audio_duration`` seconds.

        Args:
            wav_sr: ``(waveform, sample_rate)`` pair.
            audio_duration: Target duration in seconds.
            resample: When true, the waveform is first resampled to
                ``self.args.sampling_rate``.

        Returns:
            A 1-D float tensor with ``int(audio_duration * rate)`` samples,
            where ``rate`` is ``self.args.sampling_rate`` if ``resample`` is
            true and the input sample rate otherwise. Short clips are tiled;
            long clips are cropped at a random offset.
        """
        audio_time_series, sample_rate = wav_sr
        resample_rate = self.args.sampling_rate
        if resample:
            if sample_rate != resample_rate:
                resampler = T.Resample(sample_rate, resample_rate)
                audio_time_series = resampler(audio_time_series)
            # The waveform is now at the model rate, so every length below
            # must be computed with that rate rather than the file's rate.
            sample_rate = resample_rate

        # All channels are flattened into a single 1-D series.
        audio_time_series = audio_time_series.reshape(-1)
        num_samples = audio_time_series.shape[0]
        # int() keeps the value usable as a slice bound for float durations.
        target_len = int(audio_duration * sample_rate)

        if target_len >= num_samples:
            # audio_time_series is shorter than predefined audio duration,
            # so it is repeated and the excess part is removed
            repeat_factor = int(np.ceil(target_len / num_samples))
            audio_time_series = audio_time_series.repeat(repeat_factor)
            audio_time_series = audio_time_series[0:target_len]
        else:
            # audio_time_series is longer than predefined audio duration,
            # so a random window of the target length is kept
            start_index = random.randrange(num_samples - target_len)
            audio_time_series = audio_time_series[start_index:start_index + target_len]
        return torch.FloatTensor(audio_time_series)

    def load_audio_into_tensor(self, audio_path, audio_duration, resample=False):
        r"""Loads audio file and returns raw audio."""
        # Randomly sample a segment of audio_duration from the clip or pad to match duration
        audio_time_series, sample_rate = torchaudio.load(audio_path)
        return self.resample_and_duration((audio_time_series, sample_rate), audio_duration, resample)

    def preprocess_audio(self, audio_files, resample):
        r"""Load list of audio files and return raw audio

        Args:
            audio_files: Items are either a path (``str``) or a
                ``(waveform, sample_rate)`` tuple.
            resample: Forwarded to ``resample_and_duration``.

        Returns:
            The collated batch; each item contributes a ``(1, samples)`` tensor.

        Raises:
            TypeError: For items that are neither ``str`` nor ``tuple``.
        """
        audio_tensors = []
        for audio_file in audio_files:
            if isinstance(audio_file, str):
                audio_tensor = self.load_audio_into_tensor(audio_file, self.args.duration, resample)
            elif isinstance(audio_file, tuple):
                audio_tensor = self.resample_and_duration(audio_file, self.args.duration, resample)
            else:
                raise TypeError(f"type of audiofile is {type(audio_file)},which is not supported")
            audio_tensors.append(self._to_device(audio_tensor.reshape(1, -1)))
        return self.default_collate(audio_tensors)

    def preprocess_text(self, text_queries):
        r"""Load list of class labels and return tokenized text

        Each query is padded to ``self.args.text_len`` tokens; the entries
        listed in ``self.token_keys`` are flattened to 1-D and moved to the
        device before the batch is collated.
        """
        tokenized_texts = []
        for ttext in text_queries:
            # NOTE(review): no truncation is requested, so a query longer than
            # text_len presumably yields a longer tensor - confirm intended.
            tok = self.tokenizer.encode_plus(
                text=ttext, add_special_tokens=True, max_length=self.args.text_len,
                padding="max_length", return_tensors="pt")
            for key in self.token_keys:
                tok[key] = self._to_device(tok[key].reshape(-1))
            tokenized_texts.append(tok)
        return self.default_collate(tokenized_texts)

    def get_text_embeddings(self, class_labels):
        r"""Load list of class labels and return text embeddings"""
        preprocessed_text = self.preprocess_text(class_labels)
        text_embeddings = self._get_text_embeddings(preprocessed_text)
        # _get_text_embeddings already normalizes; this repeats the same step.
        text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
        return text_embeddings

    def get_audio_embeddings(self, audio_files, resample):
        r"""Load list of audio files and return a audio embeddings"""
        preprocessed_audio = self.preprocess_audio(audio_files, resample)
        audio_embeddings = self._get_audio_embeddings(preprocessed_audio)
        # _get_audio_embeddings already normalizes; this repeats the same step.
        audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
        return audio_embeddings

    def _get_text_embeddings(self, preprocessed_text):
        r"""Load preprocessed text and return L2-normalized text embeddings"""
        with torch.no_grad():
            text_embeddings = self.clap.caption_encoder(preprocessed_text)
            text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
            return text_embeddings

    def _get_audio_embeddings(self, preprocessed_audio):
        r"""Load preprocessed audio and return L2-normalized audio embeddings"""
        with torch.no_grad():
            # Drop the middle axis: the input is indexed as 3-D and reshaped
            # to (shape[0], shape[2]).
            preprocessed_audio = preprocessed_audio.reshape(
                preprocessed_audio.shape[0], preprocessed_audio.shape[2])
            # [0] is the audio embedding, [1] has output class probabilities
            audio_embeddings = self.clap.audio_encoder(preprocessed_audio)[0]
            audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
            return audio_embeddings

    def compute_similarity(self, audio_embeddings, text_embeddings, use_logit_scale=True):
        r"""Compute similarity between text and audio embeddings

        Returns the transpose of ``text_embeddings @ audio_embeddings.T``,
        with the text embeddings multiplied by ``exp(logit_scale)`` of the
        model first when ``use_logit_scale`` is true.
        """
        if use_logit_scale:
            logit_scale = self.clap.logit_scale.exp()
            similarity = logit_scale*text_embeddings @ audio_embeddings.T
        else:
            similarity = text_embeddings @ audio_embeddings.T
        return similarity.T

    def cal_clap_score(self, txt, audio_path):
        """Return the unscaled similarity of one text and one audio clip as a numpy value."""
        # Embeddings returned here are already normalized.
        text_embeddings = self.get_text_embeddings([txt])
        # Comparatively slow step: reads the audio and resamples it to the
        # model sampling rate.
        audio_embeddings = self.get_audio_embeddings([audio_path], resample=True)
        score = self.compute_similarity(
            audio_embeddings, text_embeddings, use_logit_scale=False).squeeze().cpu().numpy()
        return score

    def _generic_batch_inference(self, func, *args):
        r"""Process audio and/or text per batch

        Args:
            func: Callable applied to each batch.
            *args: Either ``(items, batch_size)`` or
                ``(audio_files, class_labels, batch_size)``. In the second
                form the text embeddings of ``class_labels`` are computed
                once and ``func`` is called as
                ``func(batch, class_labels, text_embeddings)``.

        Yields:
            ``func``'s result for every consecutive slice of ``batch_size``
            items, including the final (possibly shorter) one.
        """
        input_tmp = args[0]
        batch_size = args[-1]
        # args[0] has audio_files, args[1] has class_labels
        inputs = [args[0], args[1]] if len(args) == 3 else [args[0]]
        args0_len = len(args[0])
        # compute text_embeddings once for all the audio_files batches
        if len(inputs) == 2:
            text_embeddings = self.get_text_embeddings(args[1])
            inputs = [args[0], args[1], text_embeddings]
        for dataset_idx in range(0, args0_len, batch_size):
            inputs[0] = input_tmp[dataset_idx:dataset_idx + batch_size]
            # Every batch is yielded: a `return value` inside a generator is
            # invisible to iteration and would drop the last batch.
            yield func(*tuple(inputs))

    def get_audio_embeddings_per_batch(self, audio_files, batch_size, resample=False):
        r"""Load preprocessed audio and return a audio embeddings per batch

        ``resample`` is forwarded to ``get_audio_embeddings`` for every batch.
        """
        def embed_batch(batch):
            return self.get_audio_embeddings(batch, resample)

        return self._generic_batch_inference(embed_batch, audio_files, batch_size)

    def get_text_embeddings_per_batch(self, class_labels, batch_size):
        r"""Load preprocessed text and return text embeddings per batch"""
        return self._generic_batch_inference(self.get_text_embeddings, class_labels, batch_size)

    def classify_audio_files_per_batch(self, audio_files, class_labels, batch_size):
        r"""Compute classification probabilities for each audio recording in a batch and each class label"""
        # NOTE(review): classify_audio_files is not defined in this class, so
        # this lookup raises AttributeError unless it is provided elsewhere.
        return self._generic_batch_inference(self.classify_audio_files, audio_files, class_labels, batch_size)
|