haoheliu committed on
Commit 1f34ab8
1 Parent(s): 72bd0df

Upload 2 files

Files changed (2)
  1. app.py +118 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,118 @@
import streamlit as st
import torchaudio
import torch
import librosa
import librosa.display
import matplotlib.pyplot as plt
from semanticodec import SemantiCodec
import numpy as np
import tempfile
import os

# Set default parameters
DEFAULT_TOKEN_RATE = 100
DEFAULT_SEMANTIC_VOCAB_SIZE = 16384
DEFAULT_SAMPLE_RATE = 16000
device = "cuda" if torch.cuda.is_available() else "cpu"

# Title and Description
st.title("SemantiCodec: Ultra-Low Bitrate Neural Audio Codec")
st.write("""
Upload your audio file, adjust the codec parameters, and compare the original and reconstructed audio.
SemantiCodec achieves high-quality audio reconstruction with ultra-low bitrates!
""")

# Sidebar: Parameters
st.sidebar.title("Codec Parameters")
token_rate = st.sidebar.selectbox("Token Rate (tokens/sec)", [25, 50, 100], index=2)
semantic_vocab_size = st.sidebar.selectbox(
    "Semantic Vocabulary Size",
    [4096, 8192, 16384, 32768],
    index=2,
)
ddim_steps = st.sidebar.slider("DDIM Sampling Steps", 10, 100, 50, step=5)
guidance_scale = st.sidebar.slider("CFG Guidance Scale", 0.5, 5.0, 2.0, step=0.1)

# Upload Audio File
uploaded_file = st.file_uploader("Upload an audio file (WAV format)", type=["wav"])

# Helper function: Plot spectrogram
def plot_spectrogram(waveform, sample_rate, title):
    fig = plt.figure(figsize=(10, 4))
    S = librosa.feature.melspectrogram(y=waveform, sr=sample_rate, n_mels=128, fmax=sample_rate // 2)
    S_dB = librosa.power_to_db(S, ref=np.max)
    librosa.display.specshow(S_dB, sr=sample_rate, x_axis='time', y_axis='mel', cmap='viridis')
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.tight_layout()
    st.pyplot(fig)
    plt.close(fig)  # Release the figure so repeated runs do not accumulate memory

# Process Audio
if uploaded_file and st.button("Run SemantiCodec"):
    with tempfile.TemporaryDirectory() as temp_dir:
        # Save uploaded file
        input_path = os.path.join(temp_dir, "input.wav")
        with open(input_path, "wb") as f:
            f.write(uploaded_file.read())

        # Load audio
        waveform, sample_rate = torchaudio.load(input_path)

        # Resample to 16 kHz if needed
        if sample_rate != DEFAULT_SAMPLE_RATE:
            st.write(f"Resampling audio from {sample_rate} Hz to {DEFAULT_SAMPLE_RATE} Hz...")
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=DEFAULT_SAMPLE_RATE)
            waveform = resampler(waveform)
            sample_rate = DEFAULT_SAMPLE_RATE  # Update sample rate to 16 kHz

        # Convert to numpy for librosa compatibility
        waveform = waveform[0].numpy()

        # Plot Original Spectrogram (16 kHz resampled)
        st.write("Original Audio Spectrogram (Resampled to 16kHz):")
        plot_spectrogram(waveform, sample_rate, "Original Audio Spectrogram (Resampled to 16kHz)")

        # Initialize SemantiCodec
        st.write("Initializing SemantiCodec...")
        semanticodec = SemantiCodec(
            token_rate=token_rate,
            semantic_vocab_size=semantic_vocab_size,
            ddim_sample_step=ddim_steps,
            cfg_scale=guidance_scale,
        )
        semanticodec.device = device
        semanticodec.encoder = semanticodec.encoder.to(device)
        semanticodec.decoder = semanticodec.decoder.to(device)

        # Encode and Decode
        st.write("Encoding and Decoding Audio...")
        tokens = semanticodec.encode(input_path)
        reconstructed_waveform = semanticodec.decode(tokens)[0, 0]

        # Save reconstructed audio
        reconstructed_path = os.path.join(temp_dir, "reconstructed.wav")
        # decode() returns a numpy array; wrap it as a [1, num_samples] float tensor for torchaudio.save
        torchaudio.save(
            reconstructed_path,
            torch.from_numpy(reconstructed_waveform).float().unsqueeze(0),
            sample_rate,
        )

        # Plot Reconstructed Spectrogram
        st.write("Reconstructed Audio Spectrogram:")
        plot_spectrogram(reconstructed_waveform, sample_rate, "Reconstructed Audio Spectrogram")

        # Display latent code shape
        st.write(f"Shape of Latent Code: {tokens.shape}")

        # Audio Players
        st.audio(input_path, format="audio/wav")
        st.write("Original Audio")
        st.audio(reconstructed_path, format="audio/wav")
        st.write("Reconstructed Audio")

        # Download Button for Reconstructed Audio
        with open(reconstructed_path, "rb") as audio_file:
            st.download_button(
                "Download Reconstructed Audio",
                data=audio_file.read(),
                file_name="reconstructed_audio.wav",
            )


# Footer
st.write("Built with [Streamlit](https://streamlit.io) and SemantiCodec")
requirements.txt ADDED
@@ -0,0 +1,6 @@
git+https://github.com/haoheliu/SemantiCodec-inference.git
matplotlib
librosa
torch
torchaudio
streamlit