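"""Similarity search over a (title, url) catalog.

Queries are embedded with a Japanese BERT (text) and a wav2vec2 x-vector
model (audio), then ranked by cosine similarity against precomputed
reference embeddings.
"""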
import librosa
import numpy as np
import pandas as pd
import soundfile as sf
import torch
from omegaconf import OmegaConf
from pydub import AudioSegment
from transformers import (
    AutoFeatureExtractor,
    BertForSequenceClassification,
    BertJapaneseTokenizer,
    Wav2Vec2ForXVector,
)


class Search:
    def __init__(self, config):
        self.config = OmegaConf.load(config)
        # Catalog to rank: one row per item, keeping only title and URL.
        self.df = pd.read_csv(self.config.path_csv)[["title", "url"]]
        # Speaker-verification x-vector model for audio embeddings.
        self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(
            "anton-l/wav2vec2-base-superb-sv"
        )
        self.audio_model = Wav2Vec2ForXVector.from_pretrained(
            "anton-l/wav2vec2-base-superb-sv"
        )
        # Japanese BERT, used only as an encoder for text embeddings.
        self.text_tokenizer = BertJapaneseTokenizer.from_pretrained(
            "cl-tohoku/bert-base-japanese-whole-word-masking"
        )
        self.text_model = BertForSequenceClassification.from_pretrained(
            "cl-tohoku/bert-base-japanese-whole-word-masking",
            num_labels=2,
            output_attentions=False,
            output_hidden_states=True,
        ).eval()
        # Precomputed reference embeddings, aligned row-for-row with the CSV.
        self.text_reference = torch.load(self.config.path_text_embedding)
        self.audio_reference = torch.load(self.config.path_audio_embedding)
        self.similarity = torch.nn.CosineSimilarity(dim=-1)

    def search(self, text, audio, ratio, topk):
        """Rank the catalog by similarity to the query text and/or audio.

        When both inputs are given, `ratio` weights the text score against
        the audio score.
        """
        text_embed, audio_embed = self.get_embedding(text, audio)
        if text_embed is not None and audio_embed is not None:
            result = self.similarity(
                text_embed, self.text_reference
            ) * ratio + self.similarity(audio_embed, self.audio_reference) * (1 - ratio)
        elif text_embed is not None:
            result = self.similarity(text_embed, self.text_reference)
        elif audio_embed is not None:
            result = self.similarity(audio_embed, self.audio_reference)
        else:
            raise ValueError("Enter text or upload an audio file.")
        # Indices of the top-k highest similarities, best first.
        rank = np.argsort(result.numpy())[::-1][: int(topk)]
        return self.df.iloc[rank]

    def get_embedding(self, text, audio):
        # Either input may be absent; a missing modality comes back as None.
        text_embed = None if text == "" else self._get_text_embedding(text)
        audio_embed = None if audio is None else self._get_audio_embedding(audio)
        return text_embed, audio_embed

    def _get_text_embedding(self, text):
        # Encode without special tokens; the reference embeddings are
        # presumably built the same way, so this must stay consistent.
        tokenized_text = self.text_tokenizer.tokenize(text)
        indexed_tokens = self.text_tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens_tensor = torch.tensor([indexed_tokens])
        with torch.no_grad():
            outputs = self.text_model(tokens_tensor)
        # outputs[1] holds the hidden states; mean-pool the second-to-last
        # layer over tokens to get a single sentence vector.
        embedding = torch.mean(outputs[1][-2][0], dim=0).reshape(1, -1)
        return embedding

    def _get_audio_embedding(self, audio):
        # Normalize the upload to a mono 16-bit PCM WAV at the target rate,
        # then read the raw samples back for the feature extractor.
        audio = self.preprocess_audio(audio)
        song = AudioSegment.from_wav(audio)
        song = np.array(song.get_array_of_samples(), dtype="float")
        inputs = self.audio_feature_extractor(
            [song],
            sampling_rate=self.config.sample_rate,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():
            embedding = self.audio_model(**inputs).embeddings
        return embedding

    def preprocess_audio(self, audio):
        # `audio` arrives as a (sample_rate, samples) tuple, e.g. from a
        # Gradio Audio input.
        sample_rate, data = audio
        audio = "tmp.wav"
        sf.write(file=audio, data=data, samplerate=sample_rate)
        # Reload to resample to the configured rate and downmix to mono,
        # then rewrite in place as 16-bit PCM so pydub can read it back.
        y, sr = librosa.load(audio, sr=self.config.sample_rate, mono=True)
        sf.write(audio, y, sr, subtype="PCM_16")
        return audio
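

# Minimal usage sketch (assumptions: a "config.yaml" providing the keys read
# above — path_csv, path_text_embedding, path_audio_embedding, sample_rate —
# with the CSV and embedding files already built; the file name and query
# string below are illustrative, not part of this module).
if __name__ == "__main__":
    searcher = Search("config.yaml")
    # Text-only query ("piano"); `ratio` only matters when both a text and
    # an audio query are supplied.
    print(searcher.search(text="ピアノ", audio=None, ratio=0.5, topk=5))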