karlopintaric commited on
Commit
61f246b
1 Parent(s): 1fa4e61

Delete src/frontend

Browse files
src/frontend/.streamlit/config.toml DELETED
@@ -1,10 +0,0 @@
1
- [theme]
2
- base = "dark"
3
- primaryColor = "#FFFFFF"
4
- backgroundColor = "#212121"
5
- secondaryBackgroundColor = "#757575"
6
- textColor = "#FFFFFF"
7
- font = "sans serif"
8
-
9
- [browser]
10
- gatherUsageStats = false
 
 
 
 
 
 
 
 
 
 
 
src/frontend/__init__.py DELETED
File without changes
src/frontend/ui.py DELETED
@@ -1,97 +0,0 @@
1
- import json
2
-
3
- import streamlit as st
4
- from ui_backend import (
5
- check_for_api,
6
- cut_audio_file,
7
- display_predictions,
8
- load_audio,
9
- predict_multiple,
10
- predict_single,
11
- )
12
-
13
-
14
- def main():
15
- # Page settings
16
- st.set_page_config(
17
- page_title="Music Instrument Recognition", page_icon="🎸", layout="wide", initial_sidebar_state="collapsed"
18
- )
19
-
20
- # Sidebar
21
- with st.sidebar:
22
- st.title("⚙️ Settings")
23
- selected_model = st.selectbox(
24
- "Select Model",
25
- ("Accuracy", "Speed"),
26
- index=0,
27
- help="Select a slower but more accurate model or a faster but less accurate model",
28
- )
29
-
30
- # Main title
31
- st.markdown(
32
- "<h1 style='text-align: center; color: #FFFFFF; font-size: 3rem;'>Instrument Recognition 🎶</h1>",
33
- unsafe_allow_html=True,
34
- )
35
-
36
- # Upload widget
37
- audio_file = load_audio()
38
-
39
- # Send a health check request to the API in a loop until it is running
40
- api_running = check_for_api(10)
41
-
42
- # Enable or disable a button based on API status
43
- predict_valid = False
44
- cut_valid = False
45
-
46
- if api_running:
47
- st.info("API is running", icon="🤖")
48
-
49
- if audio_file:
50
- num_files = len(audio_file)
51
- st.write(f"Number of uploaded files: {num_files}")
52
- predict_valid = True
53
- if len(audio_file) > 1:
54
- cut_valid = False
55
- else:
56
- audio_file = audio_file[0]
57
- cut_valid = True
58
- name = audio_file.name
59
-
60
- if cut_valid:
61
- cut_audio = st.checkbox(
62
- "✂️ Cut duration",
63
- disabled=not predict_valid,
64
- help="Cut a long audio file. Model works best if audio is around 15 seconds",
65
- )
66
-
67
- if cut_audio:
68
- audio_file = cut_audio_file(audio_file, name)
69
-
70
- result = st.button("Predict", disabled=not predict_valid, help="Send the audio to API to get a prediction")
71
-
72
- if result:
73
- predictions = {}
74
- if isinstance(audio_file, list):
75
- predictions = predict_multiple(audio_file, selected_model)
76
-
77
- else:
78
- predictions = predict_single(audio_file, name, selected_model)
79
-
80
- # Sort the dictionary alphabetically by key
81
- sorted_predictions = dict(sorted(predictions.items()))
82
-
83
- # Convert the sorted dictionary to a JSON string
84
- json_string = json.dumps(sorted_predictions)
85
- st.download_button(
86
- label="Download JSON",
87
- file_name="predictions.json",
88
- mime="application/json",
89
- data=json_string,
90
- help="Download the predictions in JSON format",
91
- )
92
-
93
- display_predictions(sorted_predictions)
94
-
95
-
96
- if __name__ == "__main__":
97
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/frontend/ui_backend.py DELETED
@@ -1,254 +0,0 @@
1
- import io
2
- import os
3
- import time
4
- from json import JSONDecodeError
5
- import math
6
-
7
- import requests
8
- import soundfile as sf
9
- import streamlit as st
10
-
11
- if os.environ.get("IS_DOCKER", False):
12
- backend = "http://api:7860"
13
- else:
14
- backend = "http://0.0.0.0:7860"
15
-
16
- INSTRUMENTS = {
17
- "tru": "Trumpet",
18
- "sax": "Saxophone",
19
- "vio": "Violin",
20
- "gac": "Acoustic Guitar",
21
- "org": "Organ",
22
- "cla": "Clarinet",
23
- "flu": "Flute",
24
- "voi": "Voice",
25
- "gel": "Electric Guitar",
26
- "cel": "Cello",
27
- "pia": "Piano",
28
- }
29
-
30
-
31
- def load_audio():
32
- """
33
- Upload a WAV audio file and display it in a Streamlit app.
34
-
35
- :return: A BytesIO object representing the uploaded audio file, or None if no file was uploaded.
36
- :rtype: Optional[BytesIO]
37
- """
38
-
39
- audio_file = st.file_uploader(label="Upload audio file", type="wav", accept_multiple_files=True)
40
- if len(audio_file) > 0:
41
- st.audio(audio_file[0])
42
- return audio_file
43
- else:
44
- return None
45
-
46
-
47
- @st.cache_data(show_spinner=False)
48
- def check_for_api(max_tries: int):
49
- """
50
- Check if the API is running by making a health check request.
51
-
52
- :param max_tries: The maximum number of attempts to check the API's health.
53
- :type max_tries: int
54
- :return: True if the API is running, False otherwise.
55
- :rtype: bool
56
- """
57
- trial_count = 0
58
-
59
- with st.spinner("Waiting for API..."):
60
- while trial_count <= max_tries:
61
- try:
62
- response = health_check()
63
- if response:
64
- return True
65
- except requests.exceptions.ConnectionError:
66
- trial_count += 1
67
- # Handle connection error, e.g. API not yet running
68
- time.sleep(5) # Sleep for 1 second before retrying
69
- st.error("API is not running. Please refresh the page to try again.", icon="🚨")
70
- st.stop()
71
-
72
-
73
- def cut_audio_file(audio_file, name):
74
- """
75
- Cut an audio file and return the cut audio data as a tuple.
76
-
77
- :param audio_file: The path of the audio file to be cut.
78
- :type audio_file: str
79
- :param name: The name of the audio file to be cut.
80
- :type name: str
81
- :raises RuntimeError: If the audio file cannot be read.
82
- :return: A tuple containing the name and the cut audio data as a BytesIO object.
83
- :rtype: tuple
84
- """
85
- try:
86
- audio_data, sample_rate = sf.read(audio_file)
87
- except RuntimeError as e:
88
- raise e
89
-
90
- # Display audio duration
91
- duration = round(len(audio_data) / sample_rate, 2)
92
- st.info(f"Audio Duration: {duration} seconds")
93
-
94
- # Get start and end time for cutting
95
- start_time = st.number_input("Start Time (seconds)", min_value=0.0, max_value=duration - 1, step=0.1)
96
- end_time = st.number_input("End Time (seconds)", min_value=start_time, value=duration, max_value=duration, step=0.1)
97
-
98
- # Convert start and end time to sample indices
99
- start_sample = int(start_time * sample_rate)
100
- end_sample = int(end_time * sample_rate)
101
-
102
- # Cut audio
103
- cut_audio_data = audio_data[start_sample:end_sample]
104
-
105
- # Create a temporary in-memory file for cut audio
106
- audio_file = io.BytesIO()
107
- sf.write(audio_file, cut_audio_data, sample_rate, format="wav")
108
-
109
- # Display cut audio
110
- st.audio(audio_file, format="audio/wav")
111
- audio_file = (name, audio_file)
112
-
113
- return audio_file
114
-
115
-
116
- def display_predictions(predictions: dict):
117
- """
118
- Display the predictions using instrument names instead of codes.
119
-
120
- :param predictions: A dictionary containing the filenames and instruments detected in them.
121
- :type predictions: dict
122
- """
123
-
124
- # Display the results using instrument names instead of codes
125
- for filename, instruments in predictions.items():
126
- st.subheader(filename)
127
-
128
- if isinstance(instruments, str):
129
- st.write(instruments)
130
-
131
- else:
132
- with st.container():
133
- col1, col2 = st.columns([1, 3])
134
- present_instruments = [
135
- INSTRUMENTS[instrument_code] for instrument_code, presence in instruments.items() if presence
136
- ]
137
- if present_instruments:
138
- for instrument_name in present_instruments:
139
- with col1:
140
- st.write(instrument_name)
141
- with col2:
142
- st.write("✔️")
143
- else:
144
- st.write("No instruments found in this file.")
145
-
146
-
147
- def health_check():
148
- """
149
- Sends a health check request to the API and checks if it's running.
150
-
151
- :return: Returns True if the API is running, else False.
152
- :rtype: bool
153
- """
154
-
155
- # Send a health check request to the API
156
- response = requests.get(f"{backend}/health-check", timeout=100)
157
-
158
- # Check if the API is running
159
- if response.status_code == 200:
160
- return True
161
- else:
162
- return False
163
-
164
-
165
- def predict(data, model_name):
166
- """
167
- Sends a POST request to the API with the provided data and model name.
168
-
169
- :param data: The audio data to be used for prediction.
170
- :type data: bytes
171
- :param model_name: The name of the model to be used for prediction.
172
- :type model_name: str
173
- :return: The response from the API.
174
- :rtype: requests.Response
175
- """
176
-
177
- file = {"file": data}
178
- request_data = {"model_name": model_name}
179
-
180
- response = requests.post(
181
- f"{backend}/predict", params=request_data, files=file, timeout=300
182
- ) # Replace with your API endpoint URL
183
-
184
- return response
185
-
186
-
187
- @st.cache_data(show_spinner=False)
188
- def predict_single(audio_file, name, selected_model):
189
- """
190
- Predicts the instruments in a single audio file using the selected model.
191
-
192
- :param audio_file: The audio file to be used for prediction.
193
- :type audio_file: bytes
194
- :param name: The name of the audio file.
195
- :type name: str
196
- :param selected_model: The name of the selected model.
197
- :type selected_model: str
198
- :return: A dictionary containing the predicted instruments for the audio file.
199
- :rtype: dict
200
- """
201
-
202
- predictions = {}
203
-
204
- with st.spinner("Predicting instruments..."):
205
- response = predict(audio_file, selected_model)
206
-
207
- if response.status_code == 200:
208
- prediction = response.json()["prediction"]
209
- predictions[name] = prediction.get(name, "Error making prediction")
210
- else:
211
- st.write(response)
212
- try:
213
- st.json(response.json())
214
- except JSONDecodeError:
215
- st.error(response.text)
216
- st.stop()
217
- return predictions
218
-
219
-
220
- @st.cache_data(show_spinner=False)
221
- def predict_multiple(audio_files, selected_model):
222
- """
223
- Generates predictions for multiple audio files using the selected model.
224
-
225
- :param audio_files: A list of audio files to make predictions on.
226
- :type audio_files: List[UploadedFile]
227
- :param selected_model: The model to use for making predictions.
228
- :type selected_model: str
229
- :return: A dictionary where the keys are the names of the audio files and the values are the predicted labels.
230
- :rtype: Dict[str, str]
231
- """
232
-
233
- predictions = {}
234
- progress_text = "Getting predictions for all files. Please wait."
235
- progress_bar = st.empty()
236
- progress_bar.progress(0, text=progress_text)
237
-
238
- num_files = len(audio_files)
239
-
240
- for i, file in enumerate(audio_files):
241
- name = file.name
242
- response = predict(file, selected_model)
243
- if response.status_code == 200:
244
- prediction = response.json()["prediction"]
245
- predictions[name] = prediction[name]
246
- progress_bar.progress((i + 1) / num_files, text=progress_text)
247
- else:
248
- predictions[name] = "Error making prediction."
249
- progress_bar.empty()
250
- return predictions
251
-
252
-
253
- if __name__ == "__main__":
254
- pass