# Hugging Face Spaces app — audio noise-removal tool (DeepFilterNet + Streamlit).
# --- Standard library ---
import glob
import os

# --- Third-party ---
import librosa
import librosa.display
import numpy as np
import streamlit as st
import torchaudio
import torchaudio.transforms as T
from df import enhance, init_df
from matplotlib import pyplot as plt
from streamlit.components.v1 import html

# Application title shown in the page config and header.
app_title = "μμ μ΅μ λꡬ"

# Load the default DeepFilterNet model once at import time.
model, df_state, _ = init_df()

# Sample rate the DeepFilterNet model operates at (48 kHz).
df_sr = 48000
def display_audio_info(audio, title):
    """Render a spectrogram and a waveform for ``audio`` side by side.

    Parameters
    ----------
    audio : np.ndarray
        Waveform, either mono ``(n,)`` or multi-channel ``(channels, n)``.
        Values are clipped to [-1, 1]; only the first channel is displayed.
    title : str
        Heading prefix for both plots (e.g. input/output label).
    """
    # Two Streamlit columns: spectrogram on the left, waveform on the right.
    col1, col2 = st.columns(2)
    audio = np.clip(audio, -1.0, 1.0)
    # Keep only the first channel of multi-channel input.
    if len(np.shape(audio)) == 2:
        audio = audio[0]
    # Left column: spectrogram.
    with col1:
        st.markdown(f"### {title} - Spectrogram")
        D = librosa.stft(audio)  # STFT of the waveform
        S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
        fig, ax = plt.subplots()
        img = librosa.display.specshow(
            S_db, x_axis='time', y_axis='linear', ax=ax)
        fig.colorbar(img, ax=ax, format="%+2.f dB")
        st.pyplot(fig)
        # Close the figure: Streamlit re-runs the script on every interaction,
        # and unclosed matplotlib figures accumulate and leak memory.
        plt.close(fig)
    # Right column: waveform.
    with col2:
        st.markdown(f"### {title} - Waveform")
        fig, ax = plt.subplots()
        # Draw on this figure's axes explicitly (was plt.plot, which relies on
        # the global "current axes" and breaks if another figure is current).
        ax.plot(audio)
        ax.set_xticks([])
        ax.set_ylim(-1, 1)
        st.pyplot(fig)
        plt.close(fig)
def main():
    """Streamlit entry point: upload an audio file, denoise it, show results.

    Flow: configure the page, accept an upload (.wav/.mp3/.opus), resample to
    the model rate if needed, run DeepFilterNet enhancement, then play and
    visualize both the noisy input and the enhanced output.
    """
    st.set_page_config(page_title=app_title, page_icon="favicon.ico",
                       layout="centered", initial_sidebar_state="auto", menu_items=None)
    # "Buy me a coffee" widget markup, injected at the bottom of the page.
    button = """<script type="text/javascript" src="https://cdnjs.buymeacoffee.com/1.0.0/button.prod.min.js" data-name="bmc-button" data-slug="woojae" data-color="#FFDD00" data-emoji="β" data-font="Cookie" data-text="Buy me a coffee" data-outline-color="#000000" data-font-color="#000000" data-coffee-color="#ffffff" ></script>"""
    st.title(app_title)
    st.divider()
    st.header('μμ½κ² λΆνμν μμμ μ κ±°νμΈμ!')
    uploaded_file = st.file_uploader(
        "λ³νν νμΌμ μ λ‘λ ν΄μ£ΌμΈμ. (μ§μ νμ: .wav, .mp3, .opus)")
    if uploaded_file:
        # Remove files produced by previous runs before processing a new upload.
        files_to_remove = glob.glob('enhanced_*')
        for file in files_to_remove:
            os.remove(file)
        # MIME subtype, e.g. "audio/mpeg" -> "mpeg".
        uploaded_file_type = uploaded_file.type.split('/')[-1]
        if uploaded_file_type not in ['wav', 'mpeg', 'ogg']:
            st.text('μ§μνμ§ μλ νμΌ νμμ λλ€.')
        else:
            with st.spinner('μμ μ κ±°νλ μ€'):
                noisy_audio, sr = torchaudio.load(uploaded_file)
                st.audio(noisy_audio.numpy(), sample_rate=sr)
                # Resample to the model rate (48 kHz) when necessary.
                if sr != df_sr:
                    resampler = T.Resample(orig_freq=sr, new_freq=df_sr)
                    noisy_audio = resampler(noisy_audio)
                display_audio_info(noisy_audio.numpy(), "μ λ ₯")
            with st.spinner('μμ μ κ±°νλ μ€'):
                # Run DeepFilterNet enhancement at the model sample rate.
                output_audio = enhance(model, df_state, noisy_audio)
                enhanced_audio = output_audio
                st.divider()
                # Resample the enhanced audio back to the original rate for playback.
                if sr != df_sr:
                    resampler = T.Resample(orig_freq=df_sr, new_freq=sr)
                    enhanced_audio = resampler(enhanced_audio)
                st.audio(enhanced_audio.numpy(), sample_rate=sr)
                # NOTE: the plots intentionally show the 48 kHz model output,
                # matching the input plots which are also at the model rate.
                display_audio_info(output_audio.numpy(), "μΆλ ₯")
    html(button, height=70, width=240)
    # Pin the coffee-button iframe to the bottom-right corner of the page.
    st.markdown(
        """
        <style>
            iframe[width="240"] {
                position: fixed;
                bottom: 30px;
                right: 10px;
            }
        </style>
        """,
        unsafe_allow_html=True,
    )


if __name__ == '__main__':
    main()