File size: 4,223 Bytes
f41efe1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import argparse
import logging
import os

import numpy as np
import pandas as pd
import torch
from omegaconf import OmegaConf
from pydub import AudioSegment
from tqdm import trange
from transformers import (
    AutoFeatureExtractor,
    BertForSequenceClassification,
    BertJapaneseTokenizer,
    Wav2Vec2ForXVector,
)
class Embeder:
    """Create audio and text embeddings for the clips listed in a CSV file.

    The configuration file must provide ``path_csv``, ``path_data``,
    ``sample_rate``, ``path_audio_embedding`` and ``path_text_embedding``.
    The CSV must contain a ``filename`` column.
    """

    # Pretrained checkpoints used for the audio and the text embeddings.
    AUDIO_MODEL_NAME = "anton-l/wav2vec2-base-superb-sv"
    TEXT_MODEL_NAME = "cl-tohoku/bert-base-japanese-whole-word-masking"

    def __init__(self, config):
        """Load the configuration, the CSV index and the pretrained models.

        Args:
            config: Path to a YAML file readable by ``OmegaConf.load``.
        """
        self.config = OmegaConf.load(config)
        # ``config`` is only the path of the YAML file; the CSV location
        # lives on the loaded configuration object.
        self.df = pd.read_csv(self.config.path_csv)
        self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(
            self.AUDIO_MODEL_NAME
        )
        self.audio_model = Wav2Vec2ForXVector.from_pretrained(
            self.AUDIO_MODEL_NAME
        )
        self.text_tokenizer = BertJapaneseTokenizer.from_pretrained(
            self.TEXT_MODEL_NAME
        )
        self.text_model = BertForSequenceClassification.from_pretrained(
            self.TEXT_MODEL_NAME,
            num_labels=2,
            output_attentions=False,
            output_hidden_states=True,
        ).eval()

    def run(self):
        """Compute and save the audio embeddings, then the text embeddings.

        The audio step runs first because it removes the rows whose audio
        could not be embedded from ``self.df``; the text step then iterates
        over the remaining rows only, so both saved matrices have one row
        per surviving CSV row, in the same order.
        """
        self._create_audio_embed()
        self._create_text_embed()

    def _create_audio_embed(self):
        """Embed every clip, save the result and prune failed rows.

        For each CSV row the file ``new_<filename with .mp3 -> .wav>`` is
        read from ``config.path_data`` and passed through the audio model.
        Rows for which the model call raises are logged, removed from
        ``self.df`` and the pruned table is written back to
        ``config.path_csv``. The embeddings are L2-normalised along the last
        dimension and saved to ``config.path_audio_embedding``.

        Raises:
            RuntimeError: If no clip could be embedded at all.
        """
        log = logging.getLogger(__name__)
        embeddings = []
        failed = []  # positional row numbers whose model call raised
        for i in trange(len(self.df)):
            wav_name = "new_" + self.df.iloc[i]["filename"].replace(
                ".mp3", ".wav"
            )
            song = AudioSegment.from_wav(
                os.path.join(self.config.path_data, wav_name)
            )
            samples = np.array(song.get_array_of_samples(), dtype="float")
            inputs = self.audio_feature_extractor(
                [samples],
                sampling_rate=self.config.sample_rate,
                return_tensors="pt",
                padding=True,
            )
            try:
                with torch.no_grad():
                    embeddings.append(self.audio_model(**inputs).embeddings)
            except Exception:
                # Best effort: keep going, but leave a trace of what was lost.
                log.warning(
                    "Audio embedding failed for %s; dropping the row",
                    wav_name,
                    exc_info=True,
                )
                failed.append(i)

        if not embeddings:
            raise RuntimeError(
                "No audio embedding could be computed; nothing to save."
            )

        audio_embed = torch.nn.functional.normalize(
            torch.cat(embeddings), dim=-1
        ).cpu()
        # Failed clips never contributed a row to ``audio_embed``, so there
        # is nothing left to filter out before saving.
        self.clean_and_save_data(audio_embed, [])

        if failed:
            # ``failed`` holds positions; translate them to index labels so
            # the drop is correct even when the index is not 0..n-1.
            self.df = self.df.drop(index=self.df.index[failed])
        self.df.to_csv(self.config.path_csv, index=False)

    def _create_text_embed(self):
        """Embed every file name (without ``.mp3``) with the text model.

        Each sentence vector is the mean over tokens of the second-to-last
        entry of the model's hidden states. The stacked vectors are
        L2-normalised along the last dimension and saved to
        ``config.path_text_embedding``.

        Raises:
            RuntimeError: If ``self.df`` has no rows.
        """
        sentence_vectors = []
        for filename in self.df["filename"]:
            sentence = filename.replace(".mp3", "")
            # NOTE(review): tokenize + convert_tokens_to_ids does not appear
            # to add [CLS]/[SEP] special tokens -- confirm this is intended.
            tokens = self.text_tokenizer.tokenize(sentence)
            token_ids = self.text_tokenizer.convert_tokens_to_ids(tokens)
            with torch.no_grad():
                outputs = self.text_model(torch.tensor([token_ids]))
            # [-2]: second-to-last hidden state, [0]: the only batch item.
            token_states = outputs.hidden_states[-2][0]
            sentence_vectors.append(token_states.mean(dim=0, keepdim=True))

        if not sentence_vectors:
            raise RuntimeError(
                "No text embedding could be computed; nothing to save."
            )

        text_embed = torch.nn.functional.normalize(
            torch.cat(sentence_vectors), dim=-1
        ).cpu()
        torch.save(text_embed, self.config.path_text_embedding)

    def clean_and_save_data(self, audio_embed, idx):
        """Save ``audio_embed`` without the rows listed in ``idx``.

        Args:
            audio_embed: 2-D tensor with one embedding per row.
            idx: Iterable of row positions of ``audio_embed`` to leave out.
                Positions outside the tensor are ignored.

        The remaining rows, in their original order and starting from row 0,
        are written to ``config.path_audio_embedding`` with ``torch.save``.
        """
        dropped = {int(i) for i in idx}
        kept = [row for row in range(len(audio_embed)) if row not in dropped]
        keep_index = torch.tensor(
            kept, dtype=torch.long, device=audio_embed.device
        )
        torch.save(
            audio_embed.index_select(0, keep_index),
            self.config.path_audio_embedding,
        )
def argparser(argv=None):
    """Parse the command-line options of the script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the arguments from ``sys.argv``,
            which is the behaviour of a plain ``argparser()`` call.

    Returns:
        argparse.Namespace: Parsed options; ``config`` holds the path of the
        configuration file (``"config.yaml"`` when not given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default="config.yaml",
        help="File path for config file.",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: read the config path from the command line, build
    # the embedder from it and generate both embedding files.
    cli_options = argparser()
    Embeder(cli_options.config).run()
|