Ezi commited on
Commit
baa733f
1 Parent(s): f40f237

Upload streamlit_test_space.py

Browse files
Files changed (1) hide show
  1. streamlit_test_space.py +114 -0
streamlit_test_space.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import wavmark
3
+ import streamlit as st
4
+ import os
5
+ import torch
6
+ import datetime
7
+ import numpy as np
8
+ import soundfile
9
+ from wavmark.utils import file_reader
10
+
11
+
12
+ def my_read_file(audio_path, max_second):
13
+ signal, sr, audio_length_second = file_reader.read_as_single_channel_16k(audio_path, default_sr)
14
+ if audio_length_second > max_second:
15
+ signal = signal[0:default_sr * max_second]
16
+ audio_length_second = max_second
17
+
18
+ return signal, sr, audio_length_second
19
+
20
+
21
+ def add_watermark(audio_path, watermark_text):
22
+ #t1 = time.time()
23
+ assert len(watermark_text) == 16
24
+ watermark_npy = np.array([int(i) for i in watermark_text])
25
+ signal, sr, audio_length_second = my_read_file(audio_path, max_second_encode)
26
+ watermarked_signal, _ = wavmark.encode_watermark(model, signal, watermark_npy, show_progress=False)
27
+
28
+ tmp_file_name = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + "_" + watermark_text + ".wav"
29
+ tmp_file_path = '/tmp/' + tmp_file_name
30
+ soundfile.write(tmp_file_path, watermarked_signal, sr)
31
+ #encode_time_cost = time.time() - t1
32
+ return tmp_file_path
33
+
34
+ #def encode_water()
35
+
36
+ def decode_watermark(audio_path):
37
+ assert os.path.exists(audio_path)
38
+
39
+ #t1 = time.time()
40
+ signal, sr, audio_length_second = my_read_file(audio_path, max_second_decode)
41
+ payload_decoded, _ = wavmark.decode_watermark(model, signal, show_progress=False)
42
+ decode_cost = time.time() - t1
43
+
44
+ if payload_decoded is None:
45
+ return "No Watermark", decode_cost
46
+
47
+ payload_decoded_str = "".join([str(i) for i in payload_decoded])
48
+ st.write("Result:", payload_decoded_str)
49
+ #st.write("Time Cost:%d seconds" % (decode_cost))
50
+
51
+
52
+ def create_default_value():
53
+ if "def_value" not in st.session_state:
54
+ def_val_npy = np.random.choice([0, 1], size=32 - len_start_bit)
55
+ def_val_str = "".join([str(i) for i in def_val_npy])
56
+ st.session_state.def_value = def_val_str
57
+
58
+
59
+
60
+ def main():
61
+ create_default_value()
62
+
63
+ # st.title("AudioWaterMarking")
64
+ markdown_text = """
65
+ # Audio WaterMarking
66
+ You can upload an audio file and encode a custom 16-bit watermark or perform decoding from a watermarked audio.
67
+
68
+ See [WaveMarktoolkit](https://github.com/wavmark/wavmark) for further details.
69
+ """
70
+
71
+ st.markdown(markdown_text)
72
+
73
+ audio_file = st.file_uploader("Upload Audio", type=["wav", "mp3"], accept_multiple_files=False)
74
+
75
+ if audio_file:
76
+
77
+ tmp_input_audio_file = os.path.join("/tmp/", audio_file.name)
78
+ with open(tmp_input_audio_file, "wb") as f:
79
+ f.write(audio_file.getbuffer())
80
+
81
+
82
+ # st.audio(tmp_input_audio_file, format="audio/wav")
83
+
84
+ action = st.selectbox("Select Action", ["Add Watermark", "Decode Watermark"])
85
+
86
+ if action == "Add Watermark":
87
+ watermark_text = st.text_input("The watermark (0, 1 list of length-16):", value=st.session_state.def_value)
88
+ add_watermark_button = st.button("Add Watermark", key="add_watermark_btn")
89
+ if add_watermark_button:
90
+ if audio_file and watermark_text:
91
+ with st.spinner("Adding Watermark..."):
92
+ #watermarked_audio, encode_time_cost = add_watermark(tmp_input_audio_file, watermark_text)
93
+ watermarked_audio = add_watermark(tmp_input_audio_file, watermark_text)
94
+ st.write("Watermarked Audio:")
95
+ print("watermarked_audio:", watermarked_audio)
96
+ st.audio(watermarked_audio, format="audio/wav")
97
+ #st.write("Time Cost: %d seconds" % encode_time_cost)
98
+
99
+ elif action == "Decode Watermark":
100
+ if st.button("Decode"):
101
+ with st.spinner("Decoding..."):
102
+ decode_watermark(tmp_input_audio_file)
103
+
104
+
105
+ if __name__ == "__main__":
106
+ default_sr = 16000
107
+ max_second_encode = 60
108
+ max_second_decode = 30
109
+ len_start_bit = 16
110
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
111
+ model = wavmark.load_model().to(device)
112
+ main()
113
+
114
+