import librosa
import numpy as np
import pandas as pd
import soundfile as sf
import torch
from omegaconf import OmegaConf
from pydub import AudioSegment
from transformers import (
    AutoFeatureExtractor,
    BertForSequenceClassification,
    BertJapaneseTokenizer,
    Wav2Vec2ForXVector,
)


class Search:
    """Similarity search over a song catalogue by text, audio, or both.

    Queries are embedded with a Japanese BERT (text) and a wav2vec2
    x-vector model (audio), then compared by cosine similarity against
    precomputed reference embeddings, one row per entry in the CSV.
    """

    def __init__(self, config):
        self.config = OmegaConf.load(config)
        self.df = pd.read_csv(self.config.path_csv)[["title", "url"]]
        # Speaker-verification x-vector model, used here as an audio embedder.
        self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(
            "anton-l/wav2vec2-base-superb-sv"
        )
        self.audio_model = Wav2Vec2ForXVector.from_pretrained(
            "anton-l/wav2vec2-base-superb-sv"
        ).eval()
        # Japanese BERT; hidden states are exposed so a pooled sentence
        # embedding can be taken from an intermediate layer.
        self.text_tokenizer = BertJapaneseTokenizer.from_pretrained(
            "cl-tohoku/bert-base-japanese-whole-word-masking"
        )
        self.text_model = BertForSequenceClassification.from_pretrained(
            "cl-tohoku/bert-base-japanese-whole-word-masking",
            num_labels=2,
            output_attentions=False,
            output_hidden_states=True,
        ).eval()
        # Precomputed reference embeddings, aligned row-for-row with the CSV.
        self.text_reference = torch.load(self.config.path_text_embedding)
        self.audio_reference = torch.load(self.config.path_audio_embedding)
        self.similarity = torch.nn.CosineSimilarity(dim=-1)

    def search(self, text, audio, ratio, topk):
        """Return the top-k catalogue rows for a text and/or audio query.

        When both modalities are given, their similarities are blended:
        `ratio` weights the text score, `1 - ratio` the audio score.
        """
        text_embed, audio_embed = self.get_embedding(text, audio)
        if text_embed is not None and audio_embed is not None:
            result = self.similarity(
                text_embed, self.text_reference
            ) * ratio + self.similarity(audio_embed, self.audio_reference) * (1 - ratio)
        elif text_embed is not None:
            result = self.similarity(text_embed, self.text_reference)
        elif audio_embed is not None:
            result = self.similarity(audio_embed, self.audio_reference)
        else:
            raise ValueError("Provide query text and/or upload an audio file.")
        # Rank by similarity, descending, and keep the top-k indices.
        rank = np.argsort(result.numpy())[::-1][: int(topk)]
        return self.df.iloc[rank]

    def get_embedding(self, text, audio):
        text_embed = None if text == "" else self._get_text_embedding(text)
        audio_embed = None if audio is None else self._get_audio_embedding(audio)
        return text_embed, audio_embed

    def _get_text_embedding(self, text):
        # NOTE: tokens are encoded without [CLS]/[SEP]; the reference text
        # embeddings must have been built the same way for scores to match.
        tokenized_text = self.text_tokenizer.tokenize(text)
        indexed_tokens = self.text_tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens_tensor = torch.tensor([indexed_tokens])
        with torch.no_grad():
            outputs = self.text_model(tokens_tensor)
        # Mean-pool the second-to-last hidden layer over all tokens.
        embedding = torch.mean(outputs.hidden_states[-2][0], dim=0).reshape(1, -1)
        return embedding

    def _get_audio_embedding(self, audio):
        # Normalise the upload to 16-bit mono WAV at the target sample rate,
        # then read the raw samples back for the feature extractor.
        audio = self.preprocess_audio(audio)
        song = AudioSegment.from_wav(audio)
        song = np.array(song.get_array_of_samples(), dtype="float")
        inputs = self.audio_feature_extractor(
            [song],
            sampling_rate=self.config.sample_rate,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():
            embedding = self.audio_model(**inputs).embeddings
        return embedding

    def preprocess_audio(self, audio):
        # `audio` is a (sample_rate, samples) pair, e.g. from a Gradio Audio
        # input. Round-trip through a temp file to resample to mono PCM_16.
        sample_rate, data = audio
        audio = "tmp.wav"
        sf.write(file=audio, data=data, samplerate=sample_rate)
        y, sr = librosa.load(audio, sr=self.config.sample_rate, mono=True)
        sf.write(audio, y, sr, subtype="PCM_16")
        return audio
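

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the class above). The
# config keys mirror exactly what __init__ reads; every path and query value
# below is a hypothetical placeholder. An assumed config.yaml:
#
#   path_csv: data/songs.csv                 # must contain title, url columns
#   path_text_embedding: data/text_emb.pt    # tensor, one row per CSV row
#   path_audio_embedding: data/audio_emb.pt  # tensor, one row per CSV row
#   sample_rate: 16000
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    searcher = Search("config.yaml")  # hypothetical config path

    # Text-only query (audio=None). The query is Japanese because the text
    # model is a Japanese BERT; ratio only matters when both modalities are
    # given, but the argument is still required.
    print(searcher.search("静かなピアノ曲", audio=None, ratio=0.5, topk=5))

    # Audio-only query: the class expects a (sample_rate, samples) pair.
    # Note soundfile returns (samples, sample_rate), so the pair is flipped.
    data, sr = sf.read("query.wav")  # hypothetical file
    print(searcher.search("", audio=(sr, data), ratio=0.5, topk=5))

    # Mixed query: weight text similarity 0.7, audio similarity 0.3.
    print(searcher.search("静かなピアノ曲", audio=(sr, data), ratio=0.7, topk=10))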