import torch
from speechbrain.pretrained import Pretrained


class CustomEncoderWav2vec2Classifier(Pretrained):
    """A ready-to-use class for utterance-level classification (e.g.,
    speaker-id, language-id, emotion recognition, keyword spotting, etc.).

    The class assumes that a self-supervised encoder like wav2vec2/HuBERT
    and a classifier model are defined in the yaml file. The code below
    uses ``mods.wav2vec2``, ``mods.avg_pool`` and ``mods.output_mlp``, plus
    ``hparams.softmax`` and ``hparams.label_encoder``. If you want to
    convert the predicted index into a corresponding text label, please
    provide the path of the label_encoder in a variable called
    'lab_encoder_file' within the yaml.

    The class can be used either to run only the encoder (encode_batch()) to
    extract embeddings or to run a classification step (classify_batch()).

    Example
    -------
    Usage sketch (``<model-source>`` is a placeholder for a directory or hub
    id containing a compatible hyperparams yaml and checkpoints)::

        import torchaudio
        classifier = CustomEncoderWav2vec2Classifier.from_hparams(
            source="<model-source>",
            savedir="<tmpdir>",
        )
        # Compute embeddings
        signal, fs = torchaudio.load("samples/audio_samples/example1.wav")
        embeddings = classifier.encode_batch(signal)
        # Classification
        out_prob, score, index, text_lab = classifier.classify_batch(signal)
    """

    # Default fairseq checkpoint used by CH(). A raw string is used so the
    # Windows backslashes are not parsed as (invalid) escape sequences; the
    # resulting path value is unchanged.
    CH_MODEL_PATH = r"D:\pycharm2020\code\yuyin_ChineseWav2vec\pretrained_models\Chinses_hubert\chinese-hubert-large-fairseq-ckpt.pt"

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Encodes the input audio into a single vector embedding.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = <this>.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model. A single 1-D waveform is also accepted
            and is treated as a batch of one. Make sure the sample rate is
            fs=16000 Hz.
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding. Defaults to all ones.
        normalize : bool
            Currently ignored: no embedding normalization is applied here.
            Kept for API compatibility.

        Returns
        -------
        torch.tensor
            The encoded batch, flattened to [batch, features].
        """
        # Manage single waveforms in input
        if len(wavs.shape) == 1:
            wavs = wavs.unsqueeze(0)

        # Assign full length if wav_lens is not assigned
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        # Storing waveform in the specified device
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()

        # Computing features and embeddings
        outputs = self.mods.wav2vec2(wavs)

        # Pool the encoder output; relative lengths are forwarded to the
        # pooling module so it can account for padding.
        outputs = self.mods.avg_pool(outputs, wav_lens)
        outputs = outputs.view(outputs.shape[0], -1)
        return outputs

    def classify_batch(self, wavs, wav_lens=None):
        """Performs classification on the top of the encoded features.

        It returns the posterior probabilities, the index and, if the label
        encoder is specified, also the text label.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model. Make sure the sample rate is
            fs=16000 Hz.
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        out_prob
            Output of ``hparams.softmax`` for each class ([batch, N_class]);
            log posterior probabilities when that module is a log-softmax.
        score
            The maximum value of ``out_prob`` for each item ([batch,]).
        index
            The indexes of the best class ([batch,]).
        text_lab
            List with the text labels corresponding to the indexes
            (label encoder should be provided).
        """
        outputs = self.encode_batch(wavs, wav_lens)
        outputs = self.mods.output_mlp(outputs)
        out_prob = self.hparams.softmax(outputs)
        score, index = torch.max(out_prob, dim=-1)
        text_lab = self.hparams.label_encoder.decode_torch(index)
        return out_prob, score, index, text_lab

    def _load_ch_model(self, model_path, device, use_half):
        """Load the fairseq model used by CH() once and cache it.

        Returns a ``(model, saved_cfg)`` tuple.
        """
        from fairseq import checkpoint_utils

        # Plain dict attribute (not an nn.Module) so the cached model is not
        # registered as a submodule of this object.
        cache = getattr(self, "_ch_model_cache", None)
        if cache is None:
            cache = {}
            self._ch_model_cache = cache

        key = (model_path, str(device), use_half)
        if key not in cache:
            print("loading model(s) from {}".format(model_path))
            models, saved_cfg, _task = checkpoint_utils.load_model_ensemble_and_task(
                [model_path],
                suffix="",
            )
            print("loaded model(s) from {}".format(model_path))
            print(f"normalize: {saved_cfg.task.normalize}")
            model = models[0].to(device)
            if use_half:
                model = model.half()
            model.eval()
            cache[key] = (model, saved_cfg)
        return cache[key]

    def CH(self, wavs, wav_lens=None, model_path=None):
        """Embeds one waveform with a fairseq HuBERT checkpoint.

        Experimental alternative to ``encode_batch`` that extracts features
        with a fairseq checkpoint instead of ``mods.wav2vec2`` and then
        pools them with ``mods.avg_pool``. The checkpoint is loaded on the
        first call and reused afterwards.

        Arguments
        ---------
        wavs : torch.tensor
            A single waveform, [time] or [time, channels]; a 2-D input is
            averaged over its last dimension. NOTE(review): a [batch, time]
            input would therefore be averaged over time, so only one
            utterance per call is supported.
        wav_lens : torch.tensor
            Relative lengths forwarded to ``avg_pool``. Defaults to all ones,
            as in ``encode_batch``.
        model_path : str
            Path of the fairseq checkpoint. Defaults to ``CH_MODEL_PATH``.

        Returns
        -------
        torch.tensor
            The pooled embedding flattened to [batch, features], as float32.
        """
        import torch.nn.functional as F

        if model_path is None:
            model_path = self.CH_MODEL_PATH
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # fp16 inference only on GPU: many CPU kernels have no Half support.
        use_half = device.type == "cuda"

        model, saved_cfg = self._load_ch_model(model_path, device, use_half)

        def postprocess(feats, normalize=False):
            # Collapse a 2-D signal to 1-D by averaging its last dimension.
            if feats.dim() == 2:
                feats = feats.mean(-1)
            assert feats.dim() == 1, feats.dim()
            if normalize:
                with torch.no_grad():
                    feats = F.layer_norm(feats, feats.shape)
            return feats

        feat = postprocess(wavs, normalize=saved_cfg.task.normalize)
        source = feat.view(1, -1).to(device)
        source = source.half() if use_half else source.float()
        # No padding: the single utterance fills the whole sequence.
        padding_mask = torch.zeros(source.shape, dtype=torch.bool, device=device)

        with torch.no_grad():
            extracted = model.extract_features(
                source=source, padding_mask=padding_mask
            )

        # Element 0 of extract_features' result is what gets pooled; cast to
        # float32 so it matches the fp32 modules downstream (output_mlp).
        hidden = extracted[0].float()
        if wav_lens is None:
            # assumes hidden is [batch, time, dim] - TODO confirm
            wav_lens = torch.ones(hidden.shape[0], device=hidden.device)
        else:
            wav_lens = wav_lens.to(hidden.device)

        outputs = self.mods.avg_pool(hidden, wav_lens)
        outputs = outputs.view(outputs.shape[0], -1)
        return outputs

    def classify_file(self, path):
        """Classifies the given audiofile into the given set of labels.

        Arguments
        ---------
        path : str
            Path to audio file to classify.

        Returns
        -------
        out_prob
            Output of ``hparams.softmax`` for each class ([batch, N_class]);
            log posterior probabilities when that module is a log-softmax.
        score
            The maximum value of ``out_prob`` for each item ([batch,]).
        index
            The indexes of the best class ([batch,]).
        text_lab
            List with the text labels corresponding to the indexes
            (label encoder should be provided).
        """
        waveform = self.load_audio(path)
        # Fake a batch:
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        outputs = self.encode_batch(batch, rel_length)
        outputs = self.mods.output_mlp(outputs).squeeze(1)
        out_prob = self.hparams.softmax(outputs)
        score, index = torch.max(out_prob, dim=-1)
        text_lab = self.hparams.label_encoder.decode_torch(index)
        return out_prob, score, index, text_lab

    def forward(self, wavs, wav_lens=None, normalize=False):
        """Module entry point; delegates to ``encode_batch``."""
        return self.encode_batch(
            wavs=wavs, wav_lens=wav_lens, normalize=normalize
        )