# Sound-Effect-Search / src / create_embed.py
# Author: nomnomnonono — initial commit f41efe1
import argparse
import os
import numpy as np
import pandas as pd
import torch
from omegaconf import OmegaConf
from pydub import AudioSegment
from tqdm import trange
from transformers import (
AutoFeatureExtractor,
BertForSequenceClassification,
BertJapaneseTokenizer,
Wav2Vec2ForXVector,
)
class Embeder:
    """Compute and persist audio and text embeddings for a sound-effect dataset.

    Audio embeddings come from a Wav2Vec2 x-vector model applied to each WAV
    clip; text embeddings come from averaging a hidden layer of a Japanese
    BERT model over each clip's filename (with the ``.mp3`` suffix removed).
    """

    def __init__(self, config):
        """Load the config, dataset CSV, and pretrained audio/text models.

        Args:
            config: Path to an OmegaConf YAML file. The loaded config is
                expected to provide ``path_csv``, ``path_data``,
                ``sample_rate``, ``path_audio_embedding`` and
                ``path_text_embedding``.
        """
        self.config = OmegaConf.load(config)
        # BUG FIX: the original read ``config.path_csv`` where ``config`` is
        # the raw path string passed in, not the loaded OmegaConf object —
        # that raises AttributeError. Read from the loaded config instead.
        self.df = pd.read_csv(self.config.path_csv)
        self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(
            "anton-l/wav2vec2-base-superb-sv"
        )
        self.audio_model = Wav2Vec2ForXVector.from_pretrained(
            "anton-l/wav2vec2-base-superb-sv"
        )
        self.text_tokenizer = BertJapaneseTokenizer.from_pretrained(
            "cl-tohoku/bert-base-japanese-whole-word-masking"
        )
        self.text_model = BertForSequenceClassification.from_pretrained(
            "cl-tohoku/bert-base-japanese-whole-word-masking",
            num_labels=2,
            output_attentions=False,
            output_hidden_states=True,
        ).eval()

    def run(self):
        """Create and save the audio embeddings, then the text embeddings."""
        self._create_audio_embed()
        self._create_text_embed()

    def _create_audio_embed(self):
        """Embed every WAV clip listed in the CSV and save the result.

        Rows whose clip fails feature extraction / inference are recorded in
        ``idx`` and dropped from the CSV (which is rewritten in place).
        """
        audio_embed = None
        idx = []  # df row positions whose audio could not be embedded
        for i in trange(len(self.df)):
            # Files on disk are WAV conversions named "new_<name>.wav" of the
            # "<name>.mp3" entries in the CSV.
            song = AudioSegment.from_wav(
                os.path.join(
                    self.config.path_data,
                    "new_" + self.df.iloc[i]["filename"].replace(".mp3", ".wav"),
                )
            )
            song = np.array(song.get_array_of_samples(), dtype="float")
            inputs = self.audio_feature_extractor(
                [song],
                sampling_rate=self.config.sample_rate,
                return_tensors="pt",
                padding=True,
            )
            try:
                with torch.no_grad():
                    embeddings = self.audio_model(**inputs).embeddings
                audio_embed = (
                    embeddings
                    if audio_embed is None
                    else torch.cat([audio_embed, embeddings])
                )
            except Exception:
                # Best-effort: remember the failing row and keep going.
                idx.append(i)
        # NOTE(review): if every row fails, ``audio_embed`` is still None and
        # normalize() will raise here — confirm whether that can happen.
        audio_embed = torch.nn.functional.normalize(audio_embed, dim=-1).cpu()
        self.clean_and_save_data(audio_embed, idx)
        self.df = self.df.drop(index=idx)
        self.df.to_csv(self.config.path_csv, index=False)

    def _create_text_embed(self):
        """Embed each filename (minus ".mp3") with BERT and save the result.

        Uses the mean over tokens of the second-to-last hidden layer as the
        sentence embedding, L2-normalized before saving.
        """
        text_embed = None
        for i in range(len(self.df)):
            sentence = self.df.iloc[i]["filename"].replace(".mp3", "")
            tokenized_text = self.text_tokenizer.tokenize(sentence)
            indexed_tokens = self.text_tokenizer.convert_tokens_to_ids(tokenized_text)
            tokens_tensor = torch.tensor([indexed_tokens])
            with torch.no_grad():
                # output_hidden_states=True, so index 1 of the model output
                # is the tuple of hidden states; [-2] is the second-to-last
                # layer, [0] the single batch item.
                all_encoder_layers = self.text_model(tokens_tensor)
            embedding = torch.mean(all_encoder_layers[1][-2][0], axis=0).reshape(1, -1)
            text_embed = (
                embedding
                if text_embed is None
                else torch.cat([text_embed, embedding])
            )
        text_embed = torch.nn.functional.normalize(text_embed, dim=-1).cpu()
        torch.save(text_embed, self.config.path_text_embedding)

    def clean_and_save_data(self, audio_embed, idx):
        """Save ``audio_embed`` rows whose position is not listed in ``idx``.

        Args:
            audio_embed: 2-D tensor of embeddings, one row per clip.
            idx: Row positions to exclude from the saved tensor.
        """
        # BUG FIX: the original iterated ``range(1, len(audio_embed))``,
        # silently dropping row 0 and desynchronizing the saved embeddings
        # from the CSV rows. Start from 0.
        # NOTE(review): rows that failed in _create_audio_embed were never
        # concatenated into ``audio_embed``, so filtering by df positions in
        # ``idx`` again here is misaligned whenever failures occurred —
        # verify against how the embeddings are consumed downstream.
        clean_embed = None
        for i in range(len(audio_embed)):
            if i in idx:
                continue
            clean_embed = (
                audio_embed[i].reshape(1, -1)
                if clean_embed is None
                else torch.cat([clean_embed, audio_embed[i].reshape(1, -1)])
            )
        torch.save(clean_embed, self.config.path_audio_embedding)
def argparser():
    """Parse command-line options for this script.

    Returns:
        argparse.Namespace with a single ``config`` attribute: the path to
        the YAML config file (default ``config.yaml``).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-c",
        "--config",
        type=str,
        default="config.yaml",
        help="File path for config file.",
    )
    return cli.parse_args()
if __name__ == "__main__":
    # CLI entry point: build and save all embeddings per the given config.
    cli_args = argparser()
    embeder = Embeder(cli_args.config)
    embeder.run()