karlopintaric committed on
Commit
dfd0008
1 Parent(s): d5b60a8

Upload 4 files

Browse files
Files changed (4) hide show
  1. .streamlit/config.toml +10 -0
  2. __init__.py +0 -0
  3. ui.py +97 -0
  4. ui_backend.py +254 -0
.streamlit/config.toml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [theme]
2
+ base = "dark"
3
+ primaryColor = "#FFFFFF"
4
+ backgroundColor = "#212121"
5
+ secondaryBackgroundColor = "#757575"
6
+ textColor = "#FFFFFF"
7
+ font = "sans serif"
8
+
9
+ [browser]
10
+ gatherUsageStats = false
__init__.py ADDED
File without changes
ui.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import streamlit as st
4
+ from ui_backend import (
5
+ check_for_api,
6
+ cut_audio_file,
7
+ display_predictions,
8
+ load_audio,
9
+ predict_multiple,
10
+ predict_single,
11
+ )
12
+
13
+
14
+ def main():
15
+ # Page settings
16
+ st.set_page_config(
17
+ page_title="Music Instrument Recognition", page_icon="🎸", layout="wide", initial_sidebar_state="collapsed"
18
+ )
19
+
20
+ # Sidebar
21
+ with st.sidebar:
22
+ st.title("⚙️ Settings")
23
+ selected_model = st.selectbox(
24
+ "Select Model",
25
+ ("Accuracy", "Speed"),
26
+ index=0,
27
+ help="Select a slower but more accurate model or a faster but less accurate model",
28
+ )
29
+
30
+ # Main title
31
+ st.markdown(
32
+ "<h1 style='text-align: center; color: #FFFFFF; font-size: 3rem;'>Instrument Recognition 🎶</h1>",
33
+ unsafe_allow_html=True,
34
+ )
35
+
36
+ # Upload widget
37
+ audio_file = load_audio()
38
+
39
+ # Send a health check request to the API in a loop until it is running
40
+ api_running = check_for_api(10)
41
+
42
+ # Enable or disable a button based on API status
43
+ predict_valid = False
44
+ cut_valid = False
45
+
46
+ if api_running:
47
+ st.info("API is running", icon="🤖")
48
+
49
+ if audio_file:
50
+ num_files = len(audio_file)
51
+ st.write(f"Number of uploaded files: {num_files}")
52
+ predict_valid = True
53
+ if len(audio_file) > 1:
54
+ cut_valid = False
55
+ else:
56
+ audio_file = audio_file[0]
57
+ cut_valid = True
58
+ name = audio_file.name
59
+
60
+ if cut_valid:
61
+ cut_audio = st.checkbox(
62
+ "✂️ Cut duration",
63
+ disabled=not predict_valid,
64
+ help="Cut a long audio file. Model works best if audio is around 15 seconds",
65
+ )
66
+
67
+ if cut_audio:
68
+ audio_file = cut_audio_file(audio_file, name)
69
+
70
+ result = st.button("Predict", disabled=not predict_valid, help="Send the audio to API to get a prediction")
71
+
72
+ if result:
73
+ predictions = {}
74
+ if isinstance(audio_file, list):
75
+ predictions = predict_multiple(audio_file, selected_model)
76
+
77
+ else:
78
+ predictions = predict_single(audio_file, name, selected_model)
79
+
80
+ # Sort the dictionary alphabetically by key
81
+ sorted_predictions = dict(sorted(predictions.items()))
82
+
83
+ # Convert the sorted dictionary to a JSON string
84
+ json_string = json.dumps(sorted_predictions)
85
+ st.download_button(
86
+ label="Download JSON",
87
+ file_name="predictions.json",
88
+ mime="application/json",
89
+ data=json_string,
90
+ help="Download the predictions in JSON format",
91
+ )
92
+
93
+ display_predictions(sorted_predictions)
94
+
95
+
96
+ if __name__ == "__main__":
97
+ main()
ui_backend.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import time
4
+ from json import JSONDecodeError
5
+ import math
6
+
7
+ import requests
8
+ import soundfile as sf
9
+ import streamlit as st
10
+
11
+ if os.environ.get("IS_DOCKER", False):
12
+ backend = "http://api:7860"
13
+ else:
14
+ backend = "http://0.0.0.0:7860"
15
+
16
+ INSTRUMENTS = {
17
+ "tru": "Trumpet",
18
+ "sax": "Saxophone",
19
+ "vio": "Violin",
20
+ "gac": "Acoustic Guitar",
21
+ "org": "Organ",
22
+ "cla": "Clarinet",
23
+ "flu": "Flute",
24
+ "voi": "Voice",
25
+ "gel": "Electric Guitar",
26
+ "cel": "Cello",
27
+ "pia": "Piano",
28
+ }
29
+
30
+
31
+ def load_audio():
32
+ """
33
+ Upload a WAV audio file and display it in a Streamlit app.
34
+
35
+ :return: A BytesIO object representing the uploaded audio file, or None if no file was uploaded.
36
+ :rtype: Optional[BytesIO]
37
+ """
38
+
39
+ audio_file = st.file_uploader(label="Upload audio file", type="wav", accept_multiple_files=True)
40
+ if len(audio_file) > 0:
41
+ st.audio(audio_file[0])
42
+ return audio_file
43
+ else:
44
+ return None
45
+
46
+
47
+ @st.cache_data(show_spinner=False)
48
+ def check_for_api(max_tries: int):
49
+ """
50
+ Check if the API is running by making a health check request.
51
+
52
+ :param max_tries: The maximum number of attempts to check the API's health.
53
+ :type max_tries: int
54
+ :return: True if the API is running, False otherwise.
55
+ :rtype: bool
56
+ """
57
+ trial_count = 0
58
+
59
+ with st.spinner("Waiting for API..."):
60
+ while trial_count <= max_tries:
61
+ try:
62
+ response = health_check()
63
+ if response:
64
+ return True
65
+ except requests.exceptions.ConnectionError:
66
+ trial_count += 1
67
+ # Handle connection error, e.g. API not yet running
68
+ time.sleep(5) # Sleep for 1 second before retrying
69
+ st.error("API is not running. Please refresh the page to try again.", icon="🚨")
70
+ st.stop()
71
+
72
+
73
+ def cut_audio_file(audio_file, name):
74
+ """
75
+ Cut an audio file and return the cut audio data as a tuple.
76
+
77
+ :param audio_file: The path of the audio file to be cut.
78
+ :type audio_file: str
79
+ :param name: The name of the audio file to be cut.
80
+ :type name: str
81
+ :raises RuntimeError: If the audio file cannot be read.
82
+ :return: A tuple containing the name and the cut audio data as a BytesIO object.
83
+ :rtype: tuple
84
+ """
85
+ try:
86
+ audio_data, sample_rate = sf.read(audio_file)
87
+ except RuntimeError as e:
88
+ raise e
89
+
90
+ # Display audio duration
91
+ duration = round(len(audio_data) / sample_rate, 2)
92
+ st.info(f"Audio Duration: {duration} seconds")
93
+
94
+ # Get start and end time for cutting
95
+ start_time = st.number_input("Start Time (seconds)", min_value=0.0, max_value=duration - 1, step=0.1)
96
+ end_time = st.number_input("End Time (seconds)", min_value=start_time, value=duration, max_value=duration, step=0.1)
97
+
98
+ # Convert start and end time to sample indices
99
+ start_sample = int(start_time * sample_rate)
100
+ end_sample = int(end_time * sample_rate)
101
+
102
+ # Cut audio
103
+ cut_audio_data = audio_data[start_sample:end_sample]
104
+
105
+ # Create a temporary in-memory file for cut audio
106
+ audio_file = io.BytesIO()
107
+ sf.write(audio_file, cut_audio_data, sample_rate, format="wav")
108
+
109
+ # Display cut audio
110
+ st.audio(audio_file, format="audio/wav")
111
+ audio_file = (name, audio_file)
112
+
113
+ return audio_file
114
+
115
+
116
+ def display_predictions(predictions: dict):
117
+ """
118
+ Display the predictions using instrument names instead of codes.
119
+
120
+ :param predictions: A dictionary containing the filenames and instruments detected in them.
121
+ :type predictions: dict
122
+ """
123
+
124
+ # Display the results using instrument names instead of codes
125
+ for filename, instruments in predictions.items():
126
+ st.subheader(filename)
127
+
128
+ if isinstance(instruments, str):
129
+ st.write(instruments)
130
+
131
+ else:
132
+ with st.container():
133
+ col1, col2 = st.columns([1, 3])
134
+ present_instruments = [
135
+ INSTRUMENTS[instrument_code] for instrument_code, presence in instruments.items() if presence
136
+ ]
137
+ if present_instruments:
138
+ for instrument_name in present_instruments:
139
+ with col1:
140
+ st.write(instrument_name)
141
+ with col2:
142
+ st.write("✔️")
143
+ else:
144
+ st.write("No instruments found in this file.")
145
+
146
+
147
+ def health_check():
148
+ """
149
+ Sends a health check request to the API and checks if it's running.
150
+
151
+ :return: Returns True if the API is running, else False.
152
+ :rtype: bool
153
+ """
154
+
155
+ # Send a health check request to the API
156
+ response = requests.get(f"{backend}/health-check", timeout=100)
157
+
158
+ # Check if the API is running
159
+ if response.status_code == 200:
160
+ return True
161
+ else:
162
+ return False
163
+
164
+
165
+ def predict(data, model_name):
166
+ """
167
+ Sends a POST request to the API with the provided data and model name.
168
+
169
+ :param data: The audio data to be used for prediction.
170
+ :type data: bytes
171
+ :param model_name: The name of the model to be used for prediction.
172
+ :type model_name: str
173
+ :return: The response from the API.
174
+ :rtype: requests.Response
175
+ """
176
+
177
+ file = {"file": data}
178
+ request_data = {"model_name": model_name}
179
+
180
+ response = requests.post(
181
+ f"{backend}/predict", params=request_data, files=file, timeout=300
182
+ ) # Replace with your API endpoint URL
183
+
184
+ return response
185
+
186
+
187
+ @st.cache_data(show_spinner=False)
188
+ def predict_single(audio_file, name, selected_model):
189
+ """
190
+ Predicts the instruments in a single audio file using the selected model.
191
+
192
+ :param audio_file: The audio file to be used for prediction.
193
+ :type audio_file: bytes
194
+ :param name: The name of the audio file.
195
+ :type name: str
196
+ :param selected_model: The name of the selected model.
197
+ :type selected_model: str
198
+ :return: A dictionary containing the predicted instruments for the audio file.
199
+ :rtype: dict
200
+ """
201
+
202
+ predictions = {}
203
+
204
+ with st.spinner("Predicting instruments..."):
205
+ response = predict(audio_file, selected_model)
206
+
207
+ if response.status_code == 200:
208
+ prediction = response.json()["prediction"]
209
+ predictions[name] = prediction.get(name, "Error making prediction")
210
+ else:
211
+ st.write(response)
212
+ try:
213
+ st.json(response.json())
214
+ except JSONDecodeError:
215
+ st.error(response.text)
216
+ st.stop()
217
+ return predictions
218
+
219
+
220
+ @st.cache_data(show_spinner=False)
221
+ def predict_multiple(audio_files, selected_model):
222
+ """
223
+ Generates predictions for multiple audio files using the selected model.
224
+
225
+ :param audio_files: A list of audio files to make predictions on.
226
+ :type audio_files: List[UploadedFile]
227
+ :param selected_model: The model to use for making predictions.
228
+ :type selected_model: str
229
+ :return: A dictionary where the keys are the names of the audio files and the values are the predicted labels.
230
+ :rtype: Dict[str, str]
231
+ """
232
+
233
+ predictions = {}
234
+ progress_text = "Getting predictions for all files. Please wait."
235
+ progress_bar = st.empty()
236
+ progress_bar.progress(0, text=progress_text)
237
+
238
+ num_files = len(audio_files)
239
+
240
+ for i, file in enumerate(audio_files):
241
+ name = file.name
242
+ response = predict(file, selected_model)
243
+ if response.status_code == 200:
244
+ prediction = response.json()["prediction"]
245
+ predictions[name] = prediction[name]
246
+ progress_bar.progress((i + 1) / num_files, text=progress_text)
247
+ else:
248
+ predictions[name] = "Error making prediction."
249
+ progress_bar.empty()
250
+ return predictions
251
+
252
+
253
+ if __name__ == "__main__":
254
+ pass