nomnomnonono committed
Commit f41efe1
1 Parent(s): 690861d
Files changed (8)
  1. .gitignore +4 -0
  2. README.md +14 -2
  3. app.py +40 -0
  4. config.yaml +5 -0
  5. requirements.txt +82 -0
  6. src/create_embed.py +125 -0
  7. src/scrape.py +138 -0
  8. src/search.py +88 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+ .DS_Store
+ __pycache__
+
+
README.md CHANGED
@@ -1,12 +1,24 @@
  ---
+
  title: Sound Effect Search
+
  emoji: 🏢
+
  colorFrom: gray
+
  colorTo: green
+
+ python: 3.9.7
+
  sdk: gradio
- sdk_version: 3.28.0
+
+ sdk_version: 3.23.0
+
  app_file: app.py
- pinned: false
+
+ pinned: true
+
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
app.py ADDED
@@ -0,0 +1,40 @@
+ import gradio as gr
+ from src.search import Search
+
+ search = Search("config.yaml")
+
+ with gr.Blocks() as demo:
+     gr.Markdown("Search sound effects with this demo.")
+     with gr.TabItem("Search from Audio File"):
+         with gr.Row():
+             with gr.Column(scale=1):
+                 text_input = gr.Textbox(value="太鼓", label="SE Title")
+                 audio_input = gr.Audio(source="upload")
+                 ratio = gr.Slider(minimum=0, maximum=1, value=1, label="Weight Parameter. 1 means 'use only text'. 0 means 'use only audio'.")
+                 topk = gr.Dropdown(
+                     [5, 10, 20, 30, 40, 50], value=20, label="Top K"
+                 )
+                 button = gr.Button("Search")
+             with gr.Column(scale=2):
+                 output = gr.Dataframe()
+     with gr.TabItem("Search from Microphone"):
+         with gr.Row():
+             with gr.Column(scale=1):
+                 mic_text_input = gr.Textbox(value="太鼓", label="SE Title")
+                 mic_audio_input = gr.Audio(source="microphone")
+                 mic_ratio = gr.Slider(minimum=0, maximum=1, value=1, label="Weight Parameter. 1 means 'use only text'. 0 means 'use only audio'.")
+                 mic_topk = gr.Dropdown(
+                     [5, 10, 20, 30, 40, 50], value=20, label="Top K"
+                 )
+                 mic_button = gr.Button("Search")
+             with gr.Column(scale=2):
+                 mic_output = gr.Dataframe()
+
+     button.click(
+         search.search, inputs=[text_input, audio_input, ratio, topk], outputs=output
+     )
+     mic_button.click(
+         search.search, inputs=[mic_text_input, mic_audio_input, mic_ratio, mic_topk], outputs=mic_output
+     )
+
+ demo.launch()
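
Note that app.py assumes result.csv, text.pt, and audio.pt already exist (paths set in config.yaml below). A minimal sketch of the full pipeline, inferred from the scripts in this commit and assuming it is run from the repository root:

from src.scrape import Scraper
from src.create_embed import Embeder

Scraper("config.yaml").run()   # scrape MP3s into data/ and write result.csv
Embeder("config.yaml").run()   # write the audio.pt and text.pt embedding files
# then run app.py to launch the Gradio demo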
config.yaml ADDED
@@ -0,0 +1,5 @@
+ path_data: data
+ path_csv: result.csv
+ path_text_embedding: text.pt
+ path_audio_embedding: audio.pt
+ sample_rate: 16000
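
These keys are loaded via OmegaConf and shared by src/scrape.py, src/create_embed.py, and src/search.py. A minimal sketch of how they resolve:

from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")
cfg.path_data               # "data": directory holding the downloaded audio
cfg.path_csv                # "result.csv": scraped metadata table
cfg.path_text_embedding     # "text.pt": saved BERT title embeddings
cfg.path_audio_embedding    # "audio.pt": saved wav2vec2 x-vector embeddings
cfg.sample_rate             # 16000, the input rate wav2vec2 models expect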
requirements.txt ADDED
@@ -0,0 +1,82 @@
+ aiofiles==23.1.0
+ aiohttp==3.8.4
+ aiosignal==1.3.1
+ altair==4.2.2
+ antlr4-python3-runtime==4.9.3
+ anyio==3.6.2
+ async-timeout==4.0.2
+ attrs==23.1.0
+ certifi==2022.12.7
+ cffi==1.15.1
+ charset-normalizer==3.1.0
+ click==8.1.3
+ contourpy==1.0.7
+ cycler==0.11.0
+ entrypoints==0.4
+ fastapi==0.95.1
+ ffmpy==0.3.0
+ filelock==3.12.0
+ fonttools==4.39.3
+ frozenlist==1.3.3
+ fsspec==2023.4.0
+ fugashi==1.2.1
+ gradio==3.28.1
+ gradio_client==0.1.4
+ h11==0.14.0
+ httpcore==0.17.0
+ httpx==0.24.0
+ huggingface-hub==0.14.1
+ idna==3.4
+ importlib-resources==5.12.0
+ ipadic==1.0.0
+ Jinja2==3.1.2
+ jsonschema==4.17.3
+ kiwisolver==1.4.4
+ linkify-it-py==2.0.0
+ markdown-it-py==2.2.0
+ MarkupSafe==2.1.2
+ matplotlib==3.7.1
+ mdit-py-plugins==0.3.3
+ mdurl==0.1.2
+ mpmath==1.3.0
+ multidict==6.0.4
+ networkx==3.1
+ numpy==1.24.3
+ omegaconf==2.3.0
+ orjson==3.8.11
+ packaging==23.1
+ pandas==2.0.1
+ Pillow==9.5.0
+ pycparser==2.21
+ pydantic==1.10.7
+ pydub==0.25.1
+ pyparsing==3.0.9
+ pyrsistent==0.19.3
+ PySoundFile==0.9.0.post1
+ python-dateutil==2.8.2
+ python-multipart==0.0.6
+ pytz==2023.3
+ PyYAML==6.0
+ regex==2023.3.23
+ requests==2.29.0
+ semantic-version==2.10.0
+ six==1.16.0
+ sniffio==1.3.0
+ starlette==0.26.1
+ sympy==1.11.1
+ tokenizers==0.13.3
+ toolz==0.12.0
+ torch==2.0.0
+ torchaudio==2.0.1
+ torchvision==0.15.1
+ tqdm==4.65.0
+ transformers==4.28.1
+ typing_extensions==4.5.0
+ tzdata==2023.3
+ uc-micro-py==1.0.1
+ unidic-lite==1.0.8
+ urllib3==1.26.15
+ uvicorn==0.22.0
+ websockets==11.0.2
+ yarl==1.9.2
+ zipp==3.15.0
src/create_embed.py ADDED
@@ -0,0 +1,125 @@
+ import argparse
+ import os
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ from omegaconf import OmegaConf
+ from pydub import AudioSegment
+ from tqdm import trange
+ from transformers import (
+     AutoFeatureExtractor,
+     BertForSequenceClassification,
+     BertJapaneseTokenizer,
+     Wav2Vec2ForXVector,
+ )
+
+
+ class Embeder:
+     def __init__(self, config):
+         self.config = OmegaConf.load(config)
+         self.df = pd.read_csv(self.config.path_csv)
+         self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(
+             "anton-l/wav2vec2-base-superb-sv"
+         )
+         self.audio_model = Wav2Vec2ForXVector.from_pretrained(
+             "anton-l/wav2vec2-base-superb-sv"
+         )
+         self.text_tokenizer = BertJapaneseTokenizer.from_pretrained(
+             "cl-tohoku/bert-base-japanese-whole-word-masking"
+         )
+         self.text_model = BertForSequenceClassification.from_pretrained(
+             "cl-tohoku/bert-base-japanese-whole-word-masking",
+             num_labels=2,
+             output_attentions=False,
+             output_hidden_states=True,
+         ).eval()
+
+     def run(self):
+         self._create_audio_embed()
+         self._create_text_embed()
+
+     def _create_audio_embed(self):
+         audio_embed = None
+         idx = []
+         for i in trange(len(self.df)):
+             audio = []
+             song = AudioSegment.from_wav(
+                 os.path.join(
+                     self.config.path_data,
+                     "new_" + self.df.iloc[i]["filename"].replace(".mp3", ".wav"),
+                 )
+             )
+             song = np.array(song.get_array_of_samples(), dtype="float")
+             audio.append(song)
+             inputs = self.audio_feature_extractor(
+                 audio,
+                 sampling_rate=self.config.sample_rate,
+                 return_tensors="pt",
+                 padding=True,
+             )
+             try:
+                 with torch.no_grad():
+                     embeddings = self.audio_model(**inputs).embeddings
+                 audio_embed = (
+                     embeddings
+                     if audio_embed is None
+                     else torch.concatenate([audio_embed, embeddings])
+                 )
+             except Exception:
+                 idx.append(i)
+
+         audio_embed = torch.nn.functional.normalize(audio_embed, dim=-1).cpu()
+         self.clean_and_save_data(audio_embed, idx)
+         self.df = self.df.drop(index=idx)
+         self.df.to_csv(self.config.path_csv, index=False)
+
+     def _create_text_embed(self):
+         text_embed = None
+         for i in range(len(self.df)):
+             sentence = self.df.iloc[i]["filename"].replace(".mp3", "")
+             tokenized_text = self.text_tokenizer.tokenize(sentence)
+             indexed_tokens = self.text_tokenizer.convert_tokens_to_ids(tokenized_text)
+             tokens_tensor = torch.tensor([indexed_tokens])
+             with torch.no_grad():
+                 all_encoder_layers = self.text_model(tokens_tensor)
+             embedding = torch.mean(all_encoder_layers[1][-2][0], axis=0).reshape(1, -1)
+             text_embed = (
+                 embedding
+                 if text_embed is None
+                 else torch.concatenate([text_embed, embedding])
+             )
+         text_embed = torch.nn.functional.normalize(text_embed, dim=-1).cpu()
+         torch.save(text_embed, self.config.path_text_embedding)
+
+     def clean_and_save_data(self, audio_embed, idx):
+         clean_embed = None
+         for i in range(len(audio_embed)):
+             if i in idx:
+                 continue
+             else:
+                 clean_embed = (
+                     audio_embed[i].reshape(1, -1)
+                     if clean_embed is None
+                     else torch.concatenate([clean_embed, audio_embed[i].reshape(1, -1)])
+                 )
+         torch.save(clean_embed, self.config.path_audio_embedding)
+
+
+ def argparser():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "-c",
+         "--config",
+         type=str,
+         default="config.yaml",
+         help="File path for config file.",
+     )
+     args = parser.parse_args()
+     return args
+
+
+ if __name__ == "__main__":
+     args = argparser()
+     embeder = Embeder(args.config)
+     embeder.run()
src/scrape.py ADDED
@@ -0,0 +1,138 @@
+ import argparse
+ import glob
+ import os
+ import time
+ import urllib.request
+
+ import librosa
+ import pandas as pd
+ import requests
+ import soundfile as sf
+ from bs4 import BeautifulSoup
+ from omegaconf import OmegaConf
+ from pydub import AudioSegment
+ from requests.exceptions import Timeout
+
+
+ class Scraper:
+     def __init__(self, config):
+         self.base_url = "https://soundeffect-lab.info/"
+         self.df = pd.DataFrame([], columns=["filename", "title", "category", "url"])
+         self.idx = 0
+         self.config = OmegaConf.load(config)
+         self.setup()
+         os.makedirs(self.config.path_data, exist_ok=True)
+         self.history = []
+
+     def run(self):
+         self.all_get()
+         self.preprocess()
+
+     def setup(self):
+         try:
+             html = requests.get(self.base_url, timeout=5)
+         except Timeout:
+             raise ValueError("Time Out")
+         soup = BeautifulSoup(html.content, "html.parser")
+         tags = soup.select("a")
+         self.urls = []
+         self.categories = []
+         for tag in tags:
+             category = tag.text
+             url = tag.get("href")
+             if "/sound/" in url:
+                 self.urls.append(url)
+                 self.categories.append(category)
+
+     def all_get(self):
+         for i in range(len(self.urls)):
+             now_url = self.base_url + self.urls[i][1:]
+             self.download(now_url, self.categories[i])
+         self.df.to_csv(self.config.path_csv)
+
+     def download(self, now_url, category):
+         try:
+             html = requests.get(now_url, timeout=5)
+             soup = BeautifulSoup(html.content, "html.parser")
+             body = soup.find(id="wrap").find("main")
+             tags = body.find(id="playarea").select("a")
+             count = 0
+             for tag in tags:
+                 name = tag.get("download")
+                 url = tag.get("href")
+                 filename = os.path.join(self.config.path_data, name)
+                 if os.path.exists(filename):
+                     continue
+                 try:
+                     urllib.request.urlretrieve(now_url + url, filename)
+                     title = name.replace(".mp3", "")
+                     # store the bare name; downstream code joins it with path_data
+                     self.df.loc[self.idx] = {
+                         "filename": name,
+                         "title": title,
+                         "category": category,
+                         "url": f"https://soundeffect-lab.info/sound/search.php?s={title}",
+                     }
+                     self.idx += 1
+                     time.sleep(2)
+                     count += 1
+                 except Exception:
+                     continue
+             self.history.append(category)
+             print(now_url, category, len(tags), count)
+             paths = glob.glob(os.path.join(self.config.path_data, "*"))
+             assert len(paths) == len(self.df)
+
+             others = body.find(id="pagemenu-top").select("a")
+             other_urls, other_categories = [], []
+             for other in others:
+                 other_url = other.get("href")
+                 other_name = other.find("img").get("alt")
+                 if other_name in self.history:
+                     continue
+                 other_urls.append(other_url)
+                 other_categories.append(other_name)
+             for i in range(len(other_urls)):
+                 self.download(self.base_url + other_urls[i][1:], other_categories[i])
+         except Timeout:
+             print(f"Time Out: {now_url}")
+
+     def preprocess(self):
+         for i in range(len(self.df)):
+             song = AudioSegment.from_mp3(
+                 os.path.join(self.config.path_data, self.df.iloc[i]["filename"])
+             )
+             song.export(
+                 os.path.join(
+                     self.config.path_data,
+                     self.df.iloc[i]["filename"].replace(".mp3", ".wav"),
+                 ),
+                 format="wav",
+             )
+
+         for i in range(len(self.df)):
+             file = os.path.join(
+                 self.config.path_data,
+                 self.df.iloc[i]["filename"].replace(".mp3", ".wav"),
+             )
+             y, sr = librosa.core.load(file, sr=self.config.sample_rate, mono=True)
+             folder, name = os.path.split(file)
+             sf.write(os.path.join(folder, "new_" + name), y, sr, subtype="PCM_16")
+
+
+ def argparser():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "-c",
+         "--config",
+         type=str,
+         default="config.yaml",
+         help="File path for config file.",
+     )
+     args = parser.parse_args()
+     return args
+
+
+ if __name__ == "__main__":
+     args = argparser()
+     scraper = Scraper(args.config)
+     scraper.run()
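
A small sanity check of the scraper's output, assuming a completed run (note that to_csv is called without index=False, so result.csv also carries an index column):

import pandas as pd

df = pd.read_csv("result.csv")
print(df[["filename", "title", "category", "url"]].head())
# for each row, data/ should hold <name>.mp3, <name>.wav,
# and new_<name>.wav (the 16 kHz mono PCM_16 copy used for embedding)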
src/search.py ADDED
@@ -0,0 +1,88 @@
+ import librosa
+ import numpy as np
+ import pandas as pd
+ import soundfile as sf
+ import torch
+ from omegaconf import OmegaConf
+ from pydub import AudioSegment
+ from transformers import (
+     AutoFeatureExtractor,
+     BertForSequenceClassification,
+     BertJapaneseTokenizer,
+     Wav2Vec2ForXVector,
+ )
+
+
+ class Search:
+     def __init__(self, config):
+         self.config = OmegaConf.load(config)
+         self.df = pd.read_csv(self.config.path_csv)[["title", "url"]]
+         self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(
+             "anton-l/wav2vec2-base-superb-sv"
+         )
+         self.audio_model = Wav2Vec2ForXVector.from_pretrained(
+             "anton-l/wav2vec2-base-superb-sv"
+         )
+         self.text_tokenizer = BertJapaneseTokenizer.from_pretrained(
+             "cl-tohoku/bert-base-japanese-whole-word-masking"
+         )
+         self.text_model = BertForSequenceClassification.from_pretrained(
+             "cl-tohoku/bert-base-japanese-whole-word-masking",
+             num_labels=2,
+             output_attentions=False,
+             output_hidden_states=True,
+         ).eval()
+         self.text_reference = torch.load(self.config.path_text_embedding)
+         self.audio_reference = torch.load(self.config.path_audio_embedding)
+         self.similarity = torch.nn.CosineSimilarity(dim=-1)
+
+     def search(self, text, audio, ratio, topk):
+         text_embed, audio_embed = self.get_embedding(text, audio)
+         if text_embed is not None and audio_embed is not None:
+             result = self.similarity(
+                 text_embed, self.text_reference
+             ) * ratio + self.similarity(audio_embed, self.audio_reference) * (1 - ratio)
+         elif text_embed is not None:
+             result = self.similarity(text_embed, self.text_reference)
+         elif audio_embed is not None:
+             result = self.similarity(audio_embed, self.audio_reference)
+         else:
+             raise ValueError("Enter a text query or upload an audio file.")
+         rank = np.argsort(result.numpy())[::-1][0 : int(topk)]
+         return self.df.iloc[rank]
+
+     def get_embedding(self, text, audio):
+         text_embed = None if text == "" else self._get_text_embedding(text)
+         audio_embed = None if audio is None else self._get_audio_embedding(audio)
+         return text_embed, audio_embed
+
+     def _get_text_embedding(self, text):
+         tokenized_text = self.text_tokenizer.tokenize(text)
+         indexed_tokens = self.text_tokenizer.convert_tokens_to_ids(tokenized_text)
+         tokens_tensor = torch.tensor([indexed_tokens])
+         with torch.no_grad():
+             all_encoder_layers = self.text_model(tokens_tensor)
+         embedding = torch.mean(all_encoder_layers[1][-2][0], axis=0).reshape(1, -1)
+         return embedding
+
+     def _get_audio_embedding(self, audio):
+         audio = self.preprocess_audio(audio)
+         song = AudioSegment.from_wav(audio)
+         song = np.array(song.get_array_of_samples(), dtype="float")
+         inputs = self.audio_feature_extractor(
+             [song],
+             sampling_rate=self.config.sample_rate,
+             return_tensors="pt",
+             padding=True,
+         )
+         with torch.no_grad():
+             embedding = self.audio_model(**inputs).embeddings
+         return embedding
+
+     def preprocess_audio(self, audio):
+         sample_rate, data = audio
+         audio = "tmp.wav"
+         sf.write(file=audio, data=data, samplerate=sample_rate)
+         y, sr = librosa.core.load(audio, sr=self.config.sample_rate, mono=True)
+         sf.write(audio, y, sr, subtype="PCM_16")
+         return audio
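
The index can also be queried without Gradio; a minimal sketch, assuming result.csv, text.pt, and audio.pt exist (with ratio=1 only the text branch is used, so audio may be None):

from src.search import Search

search = Search("config.yaml")
results = search.search(text="太鼓", audio=None, ratio=1, topk=5)
print(results)  # top-5 (title, url) rows ranked by cosine similarity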