import argparse
import os

import numpy as np
import pandas as pd
import torch
from omegaconf import OmegaConf
from pydub import AudioSegment
from tqdm import trange
from transformers import (
    AutoFeatureExtractor,
    BertForSequenceClassification,
    BertJapaneseTokenizer,
    Wav2Vec2ForXVector,
)


class Embedder:
    def __init__(self, config):
        self.config = OmegaConf.load(config)
        # `config` is the path string here; the CSV path lives on the loaded config.
        self.df = pd.read_csv(self.config.path_csv)
        self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(
            "anton-l/wav2vec2-base-superb-sv"
        )
        self.audio_model = Wav2Vec2ForXVector.from_pretrained(
            "anton-l/wav2vec2-base-superb-sv"
        )
        self.text_tokenizer = BertJapaneseTokenizer.from_pretrained(
            "cl-tohoku/bert-base-japanese-whole-word-masking"
        )
        self.text_model = BertForSequenceClassification.from_pretrained(
            "cl-tohoku/bert-base-japanese-whole-word-masking",
            num_labels=2,
            output_attentions=False,
            output_hidden_states=True,
        ).eval()

    def run(self):
        self._create_audio_embed()
        self._create_text_embed()

    def _create_audio_embed(self):
        audio_embed = None
        idx = []  # dataframe row indices whose audio failed to embed
        for i in trange(len(self.df)):
            song = AudioSegment.from_wav(
                os.path.join(
                    self.config.path_data,
                    "new_" + self.df.iloc[i]["filename"].replace(".mp3", ".wav"),
                )
            )
            audio = [np.array(song.get_array_of_samples(), dtype="float")]
            inputs = self.audio_feature_extractor(
                audio,
                sampling_rate=self.config.sample_rate,
                return_tensors="pt",
                padding=True,
            )
            try:
                with torch.no_grad():
                    embeddings = self.audio_model(**inputs).embeddings
                audio_embed = (
                    embeddings
                    if audio_embed is None
                    else torch.cat([audio_embed, embeddings])
                )
            except Exception:
                idx.append(i)
        audio_embed = torch.nn.functional.normalize(audio_embed, dim=-1).cpu()
        self.clean_and_save_data(audio_embed, idx)
        # Drop the failed rows so the CSV stays aligned with the saved embeddings.
        self.df = self.df.drop(index=idx)
        self.df.to_csv(self.config.path_csv, index=False)

    def _create_text_embed(self):
        text_embed = None
        for i in range(len(self.df)):
            sentence = self.df.iloc[i]["filename"].replace(".mp3", "")
            tokenized_text = self.text_tokenizer.tokenize(sentence)
            indexed_tokens = self.text_tokenizer.convert_tokens_to_ids(tokenized_text)
            tokens_tensor = torch.tensor([indexed_tokens])
            with torch.no_grad():
                outputs = self.text_model(tokens_tensor)
            # outputs[1] is the tuple of hidden states; mean-pool the tokens of
            # the second-to-last layer to get a sentence embedding.
            embedding = torch.mean(outputs[1][-2][0], dim=0).reshape(1, -1)
            text_embed = (
                embedding
                if text_embed is None
                else torch.cat([text_embed, embedding])
            )
        text_embed = torch.nn.functional.normalize(text_embed, dim=-1).cpu()
        torch.save(text_embed, self.config.path_text_embedding)

    def clean_and_save_data(self, audio_embed, idx):
        # Rows listed in `idx` raised during embedding and were never appended,
        # so `audio_embed` already excludes them and can be saved directly.
        torch.save(audio_embed, self.config.path_audio_embedding)


def argparser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default="config.yaml",
        help="File path for config file.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = argparser()
    embedder = Embedder(args.config)
    embedder.run()
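
# The YAML config loaded above is never shown in this file. A minimal sketch,
# assuming only the fields the script actually reads (field names are taken
# from the code; the values are illustrative placeholders, not the author's):
#
#   path_csv: metadata.csv               # CSV with a "filename" column of .mp3 names
#   path_data: data/                     # directory holding the new_*.wav files
#   sample_rate: 16000                   # wav2vec2-base expects 16 kHz input
#   path_audio_embedding: audio_embed.pt # output of the audio x-vector pass
#   path_text_embedding: text_embed.pt   # output of the BERT sentence pass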