Spaces:
Runtime error
Runtime error
Commit
·
dfd0008
1
Parent(s):
d5b60a8
Upload 4 files
Browse files- .streamlit/config.toml +10 -0
- __init__.py +0 -0
- ui.py +97 -0
- ui_backend.py +254 -0
.streamlit/config.toml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[theme]
|
2 |
+
base = "dark"
|
3 |
+
primaryColor = "#FFFFFF"
|
4 |
+
backgroundColor = "#212121"
|
5 |
+
secondaryBackgroundColor = "#757575"
|
6 |
+
textColor = "#FFFFFF"
|
7 |
+
font = "sans serif"
|
8 |
+
|
9 |
+
[browser]
|
10 |
+
gatherUsageStats = false
|
__init__.py
ADDED
File without changes
|
ui.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import streamlit as st
|
4 |
+
from ui_backend import (
|
5 |
+
check_for_api,
|
6 |
+
cut_audio_file,
|
7 |
+
display_predictions,
|
8 |
+
load_audio,
|
9 |
+
predict_multiple,
|
10 |
+
predict_single,
|
11 |
+
)
|
12 |
+
|
13 |
+
|
14 |
+
def main():
    """Streamlit entry point: upload WAV files, optionally trim a single
    file, send audio to the prediction API, and render/download results.

    Widget order matters in Streamlit — statements below run top-to-bottom
    on every rerun, so the layout follows statement order exactly.
    """
    # Page settings
    st.set_page_config(
        page_title="Music Instrument Recognition", page_icon="🎸", layout="wide", initial_sidebar_state="collapsed"
    )

    # Sidebar: model choice forwarded verbatim to the API ("Accuracy" vs "Speed")
    with st.sidebar:
        st.title("⚙️ Settings")
        selected_model = st.selectbox(
            "Select Model",
            ("Accuracy", "Speed"),
            index=0,
            help="Select a slower but more accurate model or a faster but less accurate model",
        )

    # Main title
    st.markdown(
        "<h1 style='text-align: center; color: #FFFFFF; font-size: 3rem;'>Instrument Recognition 🎶</h1>",
        unsafe_allow_html=True,
    )

    # Upload widget — returns a list of uploaded files, or None when empty
    audio_file = load_audio()

    # Send a health check request to the API in a loop until it is running
    api_running = check_for_api(10)

    # Enable or disable a button based on API status
    predict_valid = False
    cut_valid = False

    if api_running:
        st.info("API is running", icon="🤖")

    if audio_file:
        num_files = len(audio_file)
        st.write(f"Number of uploaded files: {num_files}")
        predict_valid = True
        # Trimming is only offered for a single upload; the one-element list
        # is unwrapped so downstream code sees a single UploadedFile.
        if len(audio_file) > 1:
            cut_valid = False
        else:
            audio_file = audio_file[0]
            cut_valid = True
            name = audio_file.name

    if cut_valid:
        cut_audio = st.checkbox(
            "✂️ Cut duration",
            disabled=not predict_valid,
            help="Cut a long audio file. Model works best if audio is around 15 seconds",
        )

        if cut_audio:
            # Replaces the UploadedFile with a (name, BytesIO) tuple —
            # requests accepts that tuple form in its files= argument.
            audio_file = cut_audio_file(audio_file, name)

    result = st.button("Predict", disabled=not predict_valid, help="Send the audio to API to get a prediction")

    if result:
        predictions = {}
        # A list means multiple uploads; anything else is a single file
        # (or the tuple produced by cut_audio_file).
        if isinstance(audio_file, list):
            predictions = predict_multiple(audio_file, selected_model)

        else:
            # `name` is guaranteed to be bound here: the single-file branch
            # above is the only way audio_file stops being a list.
            predictions = predict_single(audio_file, name, selected_model)

        # Sort the dictionary alphabetically by key
        sorted_predictions = dict(sorted(predictions.items()))

        # Convert the sorted dictionary to a JSON string
        json_string = json.dumps(sorted_predictions)
        st.download_button(
            label="Download JSON",
            file_name="predictions.json",
            mime="application/json",
            data=json_string,
            help="Download the predictions in JSON format",
        )

        display_predictions(sorted_predictions)
94 |
+
|
95 |
+
|
96 |
+
# Run the Streamlit app when executed as a script (streamlit run ui.py).
if __name__ == "__main__":
    main()
|
ui_backend.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from json import JSONDecodeError
|
5 |
+
import math
|
6 |
+
|
7 |
+
import requests
|
8 |
+
import soundfile as sf
|
9 |
+
import streamlit as st
|
10 |
+
|
11 |
+
# Resolve the API base URL: inside docker-compose the API container is
# reachable by its service name; otherwise assume a locally running server.
# NOTE(review): os.environ.get returns a *string* when the variable is set,
# so any non-empty value (even "0" or "false") selects the docker URL —
# confirm IS_DOCKER is only ever set inside containers.
if os.environ.get("IS_DOCKER", False):
    backend = "http://api:7860"
else:
    backend = "http://0.0.0.0:7860"
|
15 |
+
|
16 |
+
# Maps the 3-letter instrument codes returned by the API to human-readable
# names. Codes look like the IRMAS dataset labels — TODO confirm with the
# API's model training data.
INSTRUMENTS = {
    "tru": "Trumpet",
    "sax": "Saxophone",
    "vio": "Violin",
    "gac": "Acoustic Guitar",
    "org": "Organ",
    "cla": "Clarinet",
    "flu": "Flute",
    "voi": "Voice",
    "gel": "Electric Guitar",
    "cel": "Cello",
    "pia": "Piano",
}
|
29 |
+
|
30 |
+
|
31 |
+
def load_audio():
    """
    Upload WAV audio files and preview the first one in the Streamlit app.

    :return: The list of uploaded files, or None if nothing was uploaded.
    :rtype: Optional[list]
    """
    uploaded = st.file_uploader(label="Upload audio file", type="wav", accept_multiple_files=True)
    # With accept_multiple_files=True the widget yields a (possibly empty)
    # list; treat an empty upload set as "nothing to do".
    if not uploaded:
        return None
    # Preview only the first file to keep the page compact.
    st.audio(uploaded[0])
    return uploaded
|
45 |
+
|
46 |
+
|
47 |
+
@st.cache_data(show_spinner=False)
def check_for_api(max_tries: int):
    """
    Check if the API is running by making a health check request.

    Retries up to ``max_tries`` times, sleeping 5 seconds between attempts.
    On repeated failure an error is shown and the Streamlit script run is
    halted via ``st.stop()``.

    :param max_tries: The maximum number of attempts to check the API's health.
    :type max_tries: int
    :return: True if the API is running, False otherwise (never actually
        returned — st.stop() raises before the fall-through).
    :rtype: bool
    """
    with st.spinner("Waiting for API..."):
        for _ in range(max_tries):
            try:
                if health_check():
                    return True
            except requests.exceptions.ConnectionError:
                # API not reachable yet, e.g. the container is still starting.
                pass
            # Sleep before retrying on BOTH failure modes. The original only
            # counted ConnectionErrors, so an API that kept answering with a
            # non-200 status caused an infinite busy loop with no back-off.
            time.sleep(5)
    st.error("API is not running. Please refresh the page to try again.", icon="🚨")
    st.stop()
|
71 |
+
|
72 |
+
|
73 |
+
def cut_audio_file(audio_file, name):
    """
    Cut an audio file and return the cut audio data as a tuple.

    :param audio_file: The uploaded audio file to be cut.
    :type audio_file: UploadedFile
    :param name: The name of the audio file to be cut.
    :type name: str
    :raises RuntimeError: If the audio file cannot be read by soundfile.
    :return: A tuple containing the name and the cut audio data as a BytesIO object.
    :rtype: tuple
    """
    # soundfile raises RuntimeError on unreadable input; let it propagate
    # (the previous try/except only re-raised the same exception).
    audio_data, sample_rate = sf.read(audio_file)

    # Display audio duration
    duration = round(len(audio_data) / sample_rate, 2)
    st.info(f"Audio Duration: {duration} seconds")

    # Get start and end time for cutting. Clamp the upper bound at 0 so a
    # clip shorter than one second doesn't produce a negative max_value,
    # which would make number_input raise (max < min).
    start_time = st.number_input("Start Time (seconds)", min_value=0.0, max_value=max(duration - 1, 0.0), step=0.1)
    end_time = st.number_input("End Time (seconds)", min_value=start_time, value=duration, max_value=duration, step=0.1)

    # Convert start and end time to sample indices
    start_sample = int(start_time * sample_rate)
    end_sample = int(end_time * sample_rate)

    # Cut audio
    cut_audio_data = audio_data[start_sample:end_sample]

    # Create a temporary in-memory file for cut audio
    buffer = io.BytesIO()
    sf.write(buffer, cut_audio_data, sample_rate, format="wav")
    # Rewind so both the audio player and the later upload to the API read
    # the clip from the beginning instead of starting at EOF.
    buffer.seek(0)

    # Display cut audio
    st.audio(buffer, format="audio/wav")

    # (name, file-like) tuple — the shape requests accepts in files=.
    return (name, buffer)
|
114 |
+
|
115 |
+
|
116 |
+
def display_predictions(predictions: dict):
    """
    Display the predictions using instrument names instead of codes.

    :param predictions: A dictionary mapping filenames to either an error
        string or a {instrument_code: presence} mapping.
    :type predictions: dict
    """
    # Display the results using instrument names instead of codes
    for filename, instruments in predictions.items():
        st.subheader(filename)

        # A plain string is an error message produced by the predict step.
        if isinstance(instruments, str):
            st.write(instruments)

        else:
            with st.container():
                col1, col2 = st.columns([1, 3])
                # Fall back to the raw code so an unknown instrument code
                # coming from the API can't raise a KeyError here.
                present_instruments = [
                    INSTRUMENTS.get(instrument_code, instrument_code)
                    for instrument_code, presence in instruments.items()
                    if presence
                ]
                if present_instruments:
                    for instrument_name in present_instruments:
                        with col1:
                            st.write(instrument_name)
                        with col2:
                            st.write("✔️")
                else:
                    st.write("No instruments found in this file.")
|
145 |
+
|
146 |
+
|
147 |
+
def health_check():
    """
    Sends a health check request to the API and checks if it's running.

    :return: Returns True if the API is running, else False.
    :rtype: bool
    """
    # Probe the health-check endpoint; anything other than HTTP 200 counts
    # as "not running". Connection errors propagate to the caller.
    response = requests.get(f"{backend}/health-check", timeout=100)
    return response.status_code == 200
|
163 |
+
|
164 |
+
|
165 |
+
def predict(data, model_name):
    """
    Sends a POST request to the API with the provided data and model name.

    :param data: The audio data to be used for prediction.
    :type data: bytes
    :param model_name: The name of the model to be used for prediction.
    :type model_name: str
    :return: The response from the API.
    :rtype: requests.Response
    """
    # The audio travels in the multipart body; the model choice rides along
    # as a query-string parameter.
    return requests.post(
        f"{backend}/predict",
        params={"model_name": model_name},
        files={"file": data},
        timeout=300,
    )
|
185 |
+
|
186 |
+
|
187 |
+
@st.cache_data(show_spinner=False)
def predict_single(audio_file, name, selected_model):
    """
    Predicts the instruments in a single audio file using the selected model.

    :param audio_file: The audio file to be used for prediction.
    :type audio_file: bytes
    :param name: The name of the audio file.
    :type name: str
    :param selected_model: The name of the selected model.
    :type selected_model: str
    :return: A dictionary containing the predicted instruments for the audio file.
    :rtype: dict
    """
    with st.spinner("Predicting instruments..."):
        response = predict(audio_file, selected_model)

    # Any non-200 answer: surface the raw response to the user and halt
    # this script run.
    if response.status_code != 200:
        st.write(response)
        try:
            st.json(response.json())
        except JSONDecodeError:
            st.error(response.text)
        st.stop()

    prediction = response.json()["prediction"]
    return {name: prediction.get(name, "Error making prediction")}
|
218 |
+
|
219 |
+
|
220 |
+
@st.cache_data(show_spinner=False)
def predict_multiple(audio_files, selected_model):
    """
    Generates predictions for multiple audio files using the selected model.

    :param audio_files: A list of audio files to make predictions on.
    :type audio_files: List[UploadedFile]
    :param selected_model: The model to use for making predictions.
    :type selected_model: str
    :return: A dictionary where the keys are the names of the audio files and the values are the predicted labels.
    :rtype: Dict[str, str]
    """
    predictions = {}
    progress_text = "Getting predictions for all files. Please wait."
    progress_bar = st.empty()
    progress_bar.progress(0, text=progress_text)

    num_files = len(audio_files)

    for i, file in enumerate(audio_files):
        name = file.name
        response = predict(file, selected_model)
        if response.status_code == 200:
            prediction = response.json()["prediction"]
            # .get keeps a missing key from raising KeyError, matching the
            # defensive lookup predict_single uses.
            predictions[name] = prediction.get(name, "Error making prediction")
        else:
            predictions[name] = "Error making prediction."
        # Advance the bar for every file, including failures — the original
        # skipped this on errors, so the bar could stall below 100%.
        progress_bar.progress((i + 1) / num_files, text=progress_text)
    progress_bar.empty()
    return predictions
|
251 |
+
|
252 |
+
|
253 |
+
# This module is a library for ui.py; there is nothing to run directly.
if __name__ == "__main__":
    pass
|