|
import json |
|
|
|
import streamlit as st |
|
from ui_backend import ( |
|
check_for_api, |
|
cut_audio_file, |
|
display_predictions, |
|
load_audio, |
|
predict_multiple, |
|
predict_single, |
|
) |
|
|
|
|
|
def _render_sidebar():
    """Draw the settings sidebar and return the selected model label.

    Returns:
        The option chosen in the "Select Model" box: ``"Accuracy"`` (the
        default, index 0) or ``"Speed"``.
    """
    with st.sidebar:
        st.title("⚙️ Settings")
        return st.selectbox(
            "Select Model",
            ("Accuracy", "Speed"),
            index=0,
            help="Select a slower but more accurate model or a faster but less accurate model",
        )


def _show_predictions(predictions):
    """Sort predictions by key, offer them as a JSON download, then render them.

    Args:
        predictions: Mapping returned by ``predict_single`` / ``predict_multiple``.
            NOTE(review): presumably keyed by file name — confirm in ui_backend.
    """
    # Sorting by key gives the same order in the downloaded file and on screen.
    sorted_predictions = dict(sorted(predictions.items()))
    st.download_button(
        label="Download JSON",
        file_name="predictions.json",
        mime="application/json",
        data=json.dumps(sorted_predictions),
        help="Download the predictions in JSON format",
    )
    display_predictions(sorted_predictions)


def main():
    """Render the instrument-recognition Streamlit page.

    Flow:
        1. Page config and sidebar model choice.
        2. Header and audio upload.
        3. API health check. Nothing below it renders while the API is down.
        4. Optional trimming, offered only when exactly one file is uploaded.
        5. On "Predict", fetch the results and show them with a JSON download.
    """
    # Streamlit expects set_page_config to be the first st.* command of a page.
    st.set_page_config(
        page_title="Music Instrument Recognition",
        page_icon="🎸",
        layout="wide",
        initial_sidebar_state="collapsed",
    )

    selected_model = _render_sidebar()

    st.markdown(
        "<h1 style='text-align: center; color: #FFFFFF; font-size: 3rem;'>Instrument Recognition 🎶</h1>",
        unsafe_allow_html=True,
    )

    # Falsy when nothing is uploaded; otherwise supports len() and [0] (used
    # below). Presumably a list of uploaded files — confirm in ui_backend.load_audio.
    audio_file = load_audio()

    # NOTE(review): the meaning of 10 (timeout? retry count?) is defined in
    # ui_backend.check_for_api — confirm.
    if not check_for_api(10):
        # Predictions need the API, so no prediction controls are shown.
        return
    # icon must be a real emoji (or Material icon); Streamlit rejects other strings.
    st.info("API is running", icon="🤖")

    predict_valid = bool(audio_file)
    name = None  # only set (and only needed) when a single file is uploaded
    if audio_file:
        st.write(f"Number of uploaded files: {len(audio_file)}")
        if len(audio_file) == 1:
            # Unwrap a single upload so it can be trimmed and sent on its own.
            audio_file = audio_file[0]
            name = audio_file.name
            cut_audio = st.checkbox(
                "✂️ Cut duration",
                help="Cut a long audio file. Model works best if audio is around 15 seconds",
            )
            if cut_audio:
                audio_file = cut_audio_file(audio_file, name)

    # The button is always shown but stays disabled until something is uploaded.
    if not st.button("Predict", disabled=not predict_valid, help="Send the audio to API to get a prediction"):
        return

    # Several uploads remain a list; a single one was unwrapped (and maybe trimmed) above.
    if isinstance(audio_file, list):
        predictions = predict_multiple(audio_file, selected_model)
    else:
        predictions = predict_single(audio_file, name, selected_model)
    _show_predictions(predictions)
|
|
|
|
|
if __name__ == "__main__":
    # Build the page only when this file is the executed script, not on import.
    main()
|
|