nomnomnonono committed

Commit: f41efe1
Parent(s): 690861d

initial

Files changed:
- .gitignore +4 -0
- README.md +14 -2
- app.py +40 -0
- config.yaml +5 -0
- requirements.txt +82 -0
- src/create_embed.py +125 -0
- src/scrape.py +138 -0
- src/search.py +88 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+.DS_Store
+__pycache__
+
+
README.md CHANGED
@@ -1,12 +1,24 @@
 ---
+
 title: Sound Effect Search
+
 emoji: 🏢
+
 colorFrom: gray
+
 colorTo: green
+
+python: 3.9.7
+
 sdk: gradio
-
+
+sdk_version: 3.23.0
+
 app_file: app.py
-
+
+pinned: true
+
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
app.py ADDED
@@ -0,0 +1,40 @@
+import gradio as gr
+from src.search import Search
+
+search = Search("config.yaml")
+
+with gr.Blocks() as demo:
+    gr.Markdown("Search Sound Effect using this demo.")
+    with gr.TabItem("Search from Audio File"):
+        with gr.Row():
+            with gr.Column(scale=1):
+                text_input = gr.Textbox(value="太鼓", label="SE Title")
+                audio_input = gr.Audio(source="upload")
+                ratio = gr.Slider(minimum=0, maximum=1, value=1, label="Weight Parameter. 1 means 'use only text'. 0 means 'use only audio'.")
+                topk = gr.Dropdown(
+                    [5, 10, 20, 30, 40, 50], value=20, label="Top K"
+                )
+                button = gr.Button("Search")
+            with gr.Column(scale=2):
+                output = gr.Dataframe()
+    with gr.TabItem("Search from Microphone"):
+        with gr.Row():
+            with gr.Column(scale=1):
+                mic_text_input = gr.Textbox(value="太鼓", label="SE Title")
+                mic_audio_input = gr.Audio(source="microphone")
+                mic_ratio = gr.Slider(minimum=0, maximum=1, value=1, label="Weight Parameter. 1 means 'use only text'. 0 means 'use only audio'.")
+                mic_topk = gr.Dropdown(
+                    [5, 10, 20, 30, 40, 50], value=20, label="Top K"
+                )
+                mic_button = gr.Button("Search")
+            with gr.Column(scale=2):
+                mic_output = gr.Dataframe()
+
+    button.click(
+        search.search, inputs=[text_input, audio_input, ratio, topk], outputs=output
+    )
+    mic_button.click(
+        search.search, inputs=[mic_text_input, mic_audio_input, mic_ratio, mic_topk], outputs=mic_output
+    )
+
+demo.launch()
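Note: both tabs wire into the same Search.search backend, so the UI can be bypassed for testing. A minimal sketch, assuming result.csv, text.pt, and audio.pt have already been produced by the scripts under src/:

from src.search import Search

# Same entry point the Gradio buttons call.
search = Search("config.yaml")

# Text-only query: with ratio=1 the audio input is ignored, so it can be None.
# "太鼓" ("taiko drum") is the default query in the UI above.
print(search.search("太鼓", None, ratio=1, topk=20))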
config.yaml ADDED
@@ -0,0 +1,5 @@
+path_data: data
+path_csv: result.csv
+path_text_embedding: text.pt
+path_audio_embedding: audio.pt
+sample_rate: 16000
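For reference, these five keys are read via OmegaConf in all three scripts; a minimal sketch of the access pattern:

from omegaconf import OmegaConf

config = OmegaConf.load("config.yaml")
print(config.path_data)             # "data": directory for downloaded audio
print(config.path_csv)              # "result.csv": scraped metadata table
print(config.path_text_embedding)   # "text.pt": BERT title embeddings
print(config.path_audio_embedding)  # "audio.pt": wav2vec2 x-vector embeddings
print(config.sample_rate)           # 16000: target rate for resampling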
requirements.txt ADDED
@@ -0,0 +1,82 @@
+aiofiles==23.1.0
+aiohttp==3.8.4
+aiosignal==1.3.1
+altair==4.2.2
+antlr4-python3-runtime==4.9.3
+anyio==3.6.2
+async-timeout==4.0.2
+attrs==23.1.0
+certifi==2022.12.7
+cffi==1.15.1
+charset-normalizer==3.1.0
+click==8.1.3
+contourpy==1.0.7
+cycler==0.11.0
+entrypoints==0.4
+fastapi==0.95.1
+ffmpy==0.3.0
+filelock==3.12.0
+fonttools==4.39.3
+frozenlist==1.3.3
+fsspec==2023.4.0
+fugashi==1.2.1
+gradio==3.28.1
+gradio_client==0.1.4
+h11==0.14.0
+httpcore==0.17.0
+httpx==0.24.0
+huggingface-hub==0.14.1
+idna==3.4
+importlib-resources==5.12.0
+ipadic==1.0.0
+Jinja2==3.1.2
+jsonschema==4.17.3
+kiwisolver==1.4.4
+linkify-it-py==2.0.0
+markdown-it-py==2.2.0
+MarkupSafe==2.1.2
+matplotlib==3.7.1
+mdit-py-plugins==0.3.3
+mdurl==0.1.2
+mpmath==1.3.0
+multidict==6.0.4
+networkx==3.1
+numpy==1.24.3
+omegaconf==2.3.0
+orjson==3.8.11
+packaging==23.1
+pandas==2.0.1
+Pillow==9.5.0
+pycparser==2.21
+pydantic==1.10.7
+pydub==0.25.1
+pyparsing==3.0.9
+pyrsistent==0.19.3
+PySoundFile==0.9.0.post1
+python-dateutil==2.8.2
+python-multipart==0.0.6
+pytz==2023.3
+PyYAML==6.0
+regex==2023.3.23
+requests==2.29.0
+semantic-version==2.10.0
+six==1.16.0
+sniffio==1.3.0
+starlette==0.26.1
+sympy==1.11.1
+tokenizers==0.13.3
+toolz==0.12.0
+torch==2.0.0
+torchaudio==2.0.1
+torchvision==0.15.1
+tqdm==4.65.0
+transformers==4.28.1
+typing_extensions==4.5.0
+tzdata==2023.3
+uc-micro-py==1.0.1
+unidic-lite==1.0.8
+urllib3==1.26.15
+uvicorn==0.22.0
+websockets==11.0.2
+yarl==1.9.2
+zipp==3.15.0
src/create_embed.py ADDED
@@ -0,0 +1,125 @@
+import argparse
+import os
+
+import numpy as np
+import pandas as pd
+import torch
+from omegaconf import OmegaConf
+from pydub import AudioSegment
+from tqdm import trange
+from transformers import (
+    AutoFeatureExtractor,
+    BertForSequenceClassification,
+    BertJapaneseTokenizer,
+    Wav2Vec2ForXVector,
+)
+
+
+class Embeder:
+    def __init__(self, config):
+        self.config = OmegaConf.load(config)
+        self.df = pd.read_csv(self.config.path_csv)
+        self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(
+            "anton-l/wav2vec2-base-superb-sv"
+        )
+        self.audio_model = Wav2Vec2ForXVector.from_pretrained(
+            "anton-l/wav2vec2-base-superb-sv"
+        )
+        self.text_tokenizer = BertJapaneseTokenizer.from_pretrained(
+            "cl-tohoku/bert-base-japanese-whole-word-masking"
+        )
+        self.text_model = BertForSequenceClassification.from_pretrained(
+            "cl-tohoku/bert-base-japanese-whole-word-masking",
+            num_labels=2,
+            output_attentions=False,
+            output_hidden_states=True,
+        ).eval()
+
+    def run(self):
+        self._create_audio_embed()
+        self._create_text_embed()
+
+    def _create_audio_embed(self):
+        audio_embed = None
+        idx = []
+        for i in trange(len(self.df)):
+            audio = []
+            song = AudioSegment.from_wav(
+                os.path.join(
+                    self.config.path_data,
+                    "new_" + self.df.iloc[i]["filename"].replace(".mp3", ".wav"),
+                )
+            )
+            song = np.array(song.get_array_of_samples(), dtype="float")
+            audio.append(song)
+            inputs = self.audio_feature_extractor(
+                audio,
+                sampling_rate=self.config.sample_rate,
+                return_tensors="pt",
+                padding=True,
+            )
+            try:
+                with torch.no_grad():
+                    embeddings = self.audio_model(**inputs).embeddings
+                audio_embed = (
+                    embeddings
+                    if audio_embed is None
+                    else torch.concatenate([audio_embed, embeddings])
+                )
+            except Exception:
+                idx.append(i)
+
+        audio_embed = torch.nn.functional.normalize(audio_embed, dim=-1).cpu()
+        self.clean_and_save_data(audio_embed, idx)
+        self.df = self.df.drop(index=idx)
+        self.df.to_csv(self.config.path_csv, index=False)
+
+    def _create_text_embed(self):
+        text_embed = None
+        for i in range(len(self.df)):
+            sentence = self.df.iloc[i]["filename"].replace(".mp3", "")
+            tokenized_text = self.text_tokenizer.tokenize(sentence)
+            indexed_tokens = self.text_tokenizer.convert_tokens_to_ids(tokenized_text)
+            tokens_tensor = torch.tensor([indexed_tokens])
+            with torch.no_grad():
+                all_encoder_layers = self.text_model(tokens_tensor)
+            embedding = torch.mean(all_encoder_layers[1][-2][0], axis=0).reshape(1, -1)
+            text_embed = (
+                embedding
+                if text_embed is None
+                else torch.concatenate([text_embed, embedding])
+            )
+        text_embed = torch.nn.functional.normalize(text_embed, dim=-1).cpu()
+        torch.save(text_embed, self.config.path_text_embedding)
+
+    def clean_and_save_data(self, audio_embed, idx):
+        clean_embed = None
+        for i in range(1, len(audio_embed)):
+            if i in idx:
+                continue
+            else:
+                clean_embed = (
+                    audio_embed[i].reshape(1, -1)
+                    if clean_embed is None
+                    else torch.concatenate([clean_embed, audio_embed[i].reshape(1, -1)])
+                )
+        torch.save(clean_embed, self.config.path_audio_embedding)
+
+
+def argparser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-c",
+        "--config",
+        type=str,
+        default="config.yaml",
+        help="File path for config file.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == "__main__":
+    args = argparser()
+    embeder = Embeder(args.config)
+    embeder.run()
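Both saved tensors pass through torch.nn.functional.normalize first, so every row has unit L2 norm and the cosine similarity used in src/search.py reduces to a dot product. A quick check of that equivalence, with random vectors as stand-ins for the real embeddings:

import torch

# Random stand-ins for normalized embedding rows (768 = BERT hidden size).
reference = torch.nn.functional.normalize(torch.randn(4, 768), dim=-1)
query = torch.nn.functional.normalize(torch.randn(1, 768), dim=-1)

cos = torch.nn.CosineSimilarity(dim=-1)(query, reference)
dot = (query * reference).sum(dim=-1)
assert torch.allclose(cos, dot, atol=1e-6)  # identical on unit vectors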
src/scrape.py ADDED
@@ -0,0 +1,138 @@
+import argparse
+import glob
+import os
+import time
+import urllib.request
+
+import librosa
+import pandas as pd
+import requests
+import soundfile as sf
+from bs4 import BeautifulSoup
+from omegaconf import OmegaConf
+from pydub import AudioSegment
+from requests.exceptions import Timeout
+
+
+class Scraper:
+    def __init__(self, config):
+        self.base_url = "https://soundeffect-lab.info/"
+        self.df = pd.DataFrame([], columns=["filename", "title", "category", "url"])
+        self.idx = 0
+        self.config = OmegaConf.load(config)
+        self.setup()
+        os.makedirs(self.config.path_data, exist_ok=True)
+        self.history = []
+
+    def run(self):
+        self.all_get()
+        self.preprocess()
+
+    def setup(self):
+        try:
+            html = requests.get(self.base_url, timeout=5)
+        except Timeout:
+            raise ValueError("Time Out")
+        soup = BeautifulSoup(html.content, "html.parser")
+        tags = soup.select("a")
+        self.urls = []
+        self.categories = []
+        for tag in tags:
+            category = tag.text
+            url = tag.get("href")
+            if "/sound/" in url:
+                self.urls.append(url)
+                self.categories.append(category)
+
+    def all_get(self):
+        for i in range(len(self.urls)):
+            now_url = self.base_url + self.urls[i][1:]
+            self.download(now_url, self.categories[i])
+        self.df.to_csv(self.config.path_csv)
+
+    def download(self, now_url, category):
+        try:
+            html = requests.get(now_url, timeout=5)
+            soup = BeautifulSoup(html.content, "html.parser")
+            body = soup.find(id="wrap").find("main")
+            tags = body.find(id="playarea").select("a")
+            count = 0
+            for tag in tags:
+                name = tag.get("download")
+                url = tag.get("href")
+                filename = os.path.join(self.config.path_data, name)
+                if os.path.exists(filename):
+                    continue
+                try:
+                    urllib.request.urlretrieve(now_url + url, filename)
+                    title = name.replace(".mp3", "")
+                    self.df.loc[self.idx] = {
+                        "filename": filename,
+                        "title": title,
+                        "category": category,
+                        "url": f"https://soundeffect-lab.info/sound/search.php?s={title}",
+                    }
+                    self.idx += 1
+                    time.sleep(2)
+                    count += 1
+                except Exception:
+                    continue
+            self.history.append(category)
+            print(now_url, category, len(tags), count)
+            paths = glob.glob(os.path.join(self.config.path_data, "*"))
+            assert len(paths) == len(self.df)
+
+            others = body.find(id="pagemenu-top").select("a")
+            other_urls, other_categories = [], []
+            for other in others:
+                other_url = other.get("href")
+                other_name = other.find("img").get("alt")
+                if other_name in self.history:
+                    continue
+                other_urls.append(other_url)
+                other_categories.append(other_name)
+            for i in range(len(other_urls)):
+                self.download(self.base_url + other_urls[i][1:], other_categories[i])
+        except Timeout:
+            print(f"Time Out: {now_url}")
+
+    def preprocess(self):
+        for i in range(len(self.df)):
+            song = AudioSegment.from_mp3(
+                os.path.join(self.config.path_data, self.df.iloc[i]["filename"])
+            )
+            song.export(
+                os.path.join(
+                    self.config.path_data,
+                    self.df.iloc[i]["filename"].replace(".mp3", ".wav"),
+                ),
+                format="wav",
+            )
+
+        for i in range(len(self.df)):
+            file = os.path.join(
+                self.config.path_data,
+                self.df.iloc[i]["filename"].replace(".mp3", ".wav"),
+            )
+            y, sr = librosa.core.load(file, sr=self.config.sample_rate, mono=True)
+            dir, name = os.path.split(file)
+            sf.write(os.path.join(dir, "new_" + name), y, sr, subtype="PCM_16")
+
+
+def argparser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-c",
+        "--config",
+        type=str,
+        default="config.yaml",
+        help="File path for config file.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == "__main__":
+    args = argparser()
+    scraper = Scraper(args.config)
+    scraper.run()
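Taken together, the scripts suggest a three-step pipeline: scrape, embed, serve. A sketch of that assumed end-to-end order (each script also runs standalone through its own -c/--config CLI):

from src.scrape import Scraper
from src.create_embed import Embeder

# 1. Download MP3s from soundeffect-lab.info, convert them to 16 kHz mono
#    WAV, and write the metadata table to result.csv.
Scraper("config.yaml").run()
# 2. Encode every clip and title, writing audio.pt and text.pt.
Embeder("config.yaml").run()
# 3. Launch the demo: python app.py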
src/search.py ADDED
@@ -0,0 +1,88 @@
+import librosa
+import numpy as np
+import pandas as pd
+import soundfile as sf
+import torch
+from omegaconf import OmegaConf
+from pydub import AudioSegment
+from transformers import (
+    AutoFeatureExtractor,
+    BertForSequenceClassification,
+    BertJapaneseTokenizer,
+    Wav2Vec2ForXVector,
+)
+
+
+class Search:
+    def __init__(self, config):
+        self.config = OmegaConf.load(config)
+        self.df = pd.read_csv(self.config.path_csv)[["title", "url"]]
+        self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(
+            "anton-l/wav2vec2-base-superb-sv"
+        )
+        self.audio_model = Wav2Vec2ForXVector.from_pretrained(
+            "anton-l/wav2vec2-base-superb-sv"
+        )
+        self.text_tokenizer = BertJapaneseTokenizer.from_pretrained(
+            "cl-tohoku/bert-base-japanese-whole-word-masking"
+        )
+        self.text_model = BertForSequenceClassification.from_pretrained(
+            "cl-tohoku/bert-base-japanese-whole-word-masking",
+            num_labels=2,
+            output_attentions=False,
+            output_hidden_states=True,
+        ).eval()
+        self.text_reference = torch.load(self.config.path_text_embedding)
+        self.audio_reference = torch.load(self.config.path_audio_embedding)
+        self.similarity = torch.nn.CosineSimilarity(dim=-1)
+
+    def search(self, text, audio, ratio, topk):
+        text_embed, audio_embed = self.get_embedding(text, audio)
+        if text_embed is not None and audio_embed is not None:
+            result = self.similarity(
+                text_embed, self.text_reference
+            ) * ratio + self.similarity(audio_embed, self.audio_reference) * (1 - ratio)
+        elif text_embed is not None:
+            result = self.similarity(text_embed, self.text_reference)
+        elif audio_embed is not None:
+            result = self.similarity(audio_embed, self.audio_reference)
+        else:
+            raise ValueError("Input text or upload audio file.")
+        rank = np.argsort(result.numpy())[::-1][0 : int(topk)]
+        return self.df.iloc[rank]
+
+    def get_embedding(self, text, audio):
+        text_embed = None if text == "" else self._get_text_embedding(text)
+        audio_embed = None if audio is None else self._get_audio_embedding(audio)
+        return text_embed, audio_embed
+
+    def _get_text_embedding(self, text):
+        tokenized_text = self.text_tokenizer.tokenize(text)
+        indexed_tokens = self.text_tokenizer.convert_tokens_to_ids(tokenized_text)
+        tokens_tensor = torch.tensor([indexed_tokens])
+        with torch.no_grad():
+            all_encoder_layers = self.text_model(tokens_tensor)
+        embedding = torch.mean(all_encoder_layers[1][-2][0], axis=0).reshape(1, -1)
+        return embedding
+
+    def _get_audio_embedding(self, audio):
+        audio = self.preprocess_audio(audio)
+        song = AudioSegment.from_wav(audio)
+        song = np.array(song.get_array_of_samples(), dtype="float")
+        inputs = self.audio_feature_extractor(
+            [song],
+            sampling_rate=self.config.sample_rate,
+            return_tensors="pt",
+            padding=True,
+        )
+        with torch.no_grad():
+            embedding = self.audio_model(**inputs).embeddings
+        return embedding
+
+    def preprocess_audio(self, audio):
+        sample_rate, data = audio
+        audio = "tmp.wav"
+        sf.write(file=audio, data=data, samplerate=sample_rate)
+        y, sr = librosa.core.load(audio, sr=self.config.sample_rate, mono=True)
+        sf.write(audio, y, sr, subtype="PCM_16")
+        return audio
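The ranking in Search.search is a convex combination of two cosine similarities followed by an argsort. A standalone sketch of just that step, with random tensors standing in for the precomputed references (512 = wav2vec2 x-vector size, 768 = BERT hidden size):

import numpy as np
import torch

similarity = torch.nn.CosineSimilarity(dim=-1)

text_reference = torch.randn(100, 768)   # one row per sound effect title
audio_reference = torch.randn(100, 512)  # one row per sound effect clip
text_embed = torch.randn(1, 768)         # query title embedding
audio_embed = torch.randn(1, 512)        # query audio embedding

ratio, topk = 0.5, 20  # ratio=1: text only; ratio=0: audio only
result = similarity(text_embed, text_reference) * ratio + similarity(
    audio_embed, audio_reference
) * (1 - ratio)
rank = np.argsort(result.numpy())[::-1][:topk]  # indices of the best matches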