Vahe committed on
Commit
5508eec
1 Parent(s): 003b2a5

app.py added

Browse files
Files changed (1) hide show
  1. app.py +370 -0
app.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ from glob import glob
4
+ from pathlib import Path
5
+
6
+ # from TTS.TTS.api import TTS
7
+ from TTS.utils.synthesizer import Synthesizer
8
+ from Wav2Lip.video_generator import create_video
9
+ from diffusers import StableDiffusionPipeline
10
+ from diffusers import LMSDiscreteScheduler
11
+
12
+ gpu = False
13
+ model_path = Path(r"tss_model/model_file.pth")
14
+ config_path = Path(r"tss_model/config.json")
15
+ vocoder_path = None
16
+ vocoder_config_path = None
17
+ model_dir = None
18
+ language="en"
19
+ file_path="generated_audio.wav"
20
+ speaker = None
21
+ split_sentences = True
22
+ pipe_out = None
23
+
24
+ # def get_synthesizer(model_path, config_path, vocoder_path, vocoder_config_path, model_dir, gpu):
25
+
26
+ synthesizer = Synthesizer(
27
+ tts_checkpoint=model_path,
28
+ tts_config_path=config_path,
29
+ tts_speakers_file=None,
30
+ tts_languages_file=None,
31
+ vocoder_checkpoint=vocoder_path,
32
+ vocoder_config=vocoder_config_path,
33
+ encoder_checkpoint=None,
34
+ encoder_config=None,
35
+ model_dir=model_dir,
36
+ use_cuda=gpu,
37
+ )
38
+
39
+ # return synthesizer
40
+
41
+ # synthesizer = get_synthesizer(model_path, config_path, vocoder_path, vocoder_config_path, model_dir, gpu)
42
+
43
+ def get_audio(synthesizer, speaker, language, speaker_wav, split_sentences, text):
44
+
45
+ wav = synthesizer.tts(
46
+ text=text,
47
+ speaker_name=speaker,
48
+ language_name=language,
49
+ speaker_wav=speaker_wav,
50
+ reference_wav=None,
51
+ style_wav=None,
52
+ style_text=None,
53
+ reference_speaker_name=None,
54
+ split_sentences=split_sentences
55
+ )
56
+
57
+ synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)
58
+
59
+ # avatar_images_dir = Path('avatar_images')
60
+ avatar_images_list = os.listdir('avatar_images')
61
+ avatar_names_list = list(map(lambda x: x.split('.')[0], avatar_images_list))
62
+
63
+ n_cols_avatars = 3
64
+ n_rows_avatars = int((len(avatar_images_list) - len(avatar_images_list) % n_cols_avatars) / n_cols_avatars)
65
+ if len(avatar_images_list) % n_cols_avatars != 0:
66
+ n_rows_avatars += 1
67
+
68
+ voice_audio_list = os.listdir('voice_audios')
69
+ voice_names_list = list(map(lambda x: x.split('.')[0], voice_audio_list))
70
+
71
+ n_cols_voices = 3
72
+ n_rows_voices = int((len(voice_audio_list) - len(voice_audio_list) % n_cols_voices) / n_cols_voices)
73
+ if len(voice_audio_list) % n_cols_voices != 0:
74
+ n_rows_voices += 1
75
+
76
+ st.set_page_config(
77
+ page_title='Avatar service',
78
+ layout='wide'
79
+ )
80
+
81
+ st.markdown("<h1 style='text-align: center; color: white;'>Avatar video generation</h1>", unsafe_allow_html=True)
82
+
83
+ # st.title('Avatar video generation')
84
+
85
+ st.subheader('Step 1: Avatar Selection')
86
+
87
+ with st.expander('Available avatars'):
88
+ n_images_shown = 0
89
+ for i in range(n_rows_avatars):
90
+ avatar_cols_list = st.columns(n_cols_avatars)
91
+ for j in range(n_cols_avatars):
92
+ avatar_cols_list[j].image(
93
+ os.path.join('avatar_images', avatar_images_list[j+i*3]),
94
+ width=150,
95
+ caption=avatar_names_list[j+i*3]
96
+ )
97
+ n_images_shown += 1
98
+ if n_images_shown == len(avatar_images_list):
99
+ break
100
+
101
+ def avatar_callback():
102
+ if st.session_state.avatar_image:
103
+ st.session_state.selected_avatar = st.session_state.avatar_image
104
+
105
+ if os.path.isfile('generated_avatar.jpg'):
106
+ os.remove('generated_avatar.jpg')
107
+
108
+ # if os.path.isfile('uploaded_avatar_image.jpg'):
109
+ # os.remove('uploaded_avatar_image.jpg')
110
+
111
+ def uploaded_avatar_callback():
112
+ if st.session_state.uploaded_avatar_image is None:
113
+ pass
114
+ else:
115
+ image_path = "uploaded_avatar_image" + \
116
+ os.path.splitext(st.session_state.uploaded_avatar_image.name)[-1]
117
+ with open(image_path, "wb") as f:
118
+ f.write(st.session_state.uploaded_avatar_image.getvalue())
119
+
120
+ step1_col1, step1_col2 = st.columns(2)
121
+
122
+ with step1_col1:
123
+ selected_avatar = st.selectbox(
124
+ label='Please select an avatar',
125
+ options=avatar_names_list,
126
+ key='avatar_image',
127
+ on_change=avatar_callback
128
+ )
129
+
130
+ st.write('or')
131
+
132
+ uploaded_image = st.file_uploader(
133
+ label='Please upload an avatar',
134
+ type=['png', 'jpg', 'jpeg'],
135
+ on_change=uploaded_avatar_callback,
136
+ key='uploaded_avatar_image'
137
+ )
138
+
139
+ st.write('or')
140
+
141
+ st.text_area(
142
+ label='Please type a prompt to generate an image for the avatar',
143
+ key='image_prompt'
144
+ )
145
+
146
+ def generate_avatar():
147
+ if st.session_state.avatar_generator:
148
+
149
+ # if not os.path.exists('generated_avatars'):
150
+ # os.mkdir('generated_avatars')
151
+
152
+ pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path='diffusion_model')
153
+ pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
154
+ pipe_output = pipe(
155
+ prompt=st.session_state.image_prompt, # What to generate
156
+ negative_prompt="Oversaturated, blurry, low quality, do not show head", # What NOT to generate
157
+ height=480,
158
+ width=640, # Specify the image size
159
+ guidance_scale=13, # How strongly to follow the prompt
160
+ num_inference_steps=40, # How many steps to take
161
+ # generator=generator, # Fixed random seed
162
+ )
163
+ pipe_output.images[0].save('generated_avatar.jpg')
164
+ else:
165
+ pass
166
+
167
+ st.button(
168
+ label='generate_avatar',
169
+ key='avatar_generator',
170
+ on_click=generate_avatar
171
+ )
172
+
173
+ # st.write(st.session_state.avatar_generator)
174
+
175
+ with step1_col2:
176
+ if uploaded_image is not None:
177
+ uploaded_avatar_image_path = glob('uploaded_avatar_image.*')[0]
178
+ st.image(uploaded_avatar_image_path, width=300)
179
+ elif len(glob('generated_avatar.*')) != 0:
180
+ st.image('generated_avatar.jpg', width=300)
181
+ else:
182
+ st.image(os.path.join('avatar_images', avatar_images_list[avatar_names_list.index(selected_avatar)]), width=300)
183
+
184
+
185
+ st.subheader('Step 2: Audio Selection')
186
+ # st.markdown("<div title='Opa'>Option 1</div>", unsafe_allow_html=True)
187
+ option1_expander = st.expander('Option 1')
188
+ option1_expander.write(
189
+ '''Please select or upload an audio with a voice you want to be used in the video.
190
+ Then provide a text that will be used in the video. Afterwards click on
191
+ <Generate audio from text> button to get the audio which will be used in the video:
192
+ please, take into account that depending on the size of the text it may take some time.
193
+ '''
194
+ )
195
+
196
+ with st.expander('Available voice audio'):
197
+ n_voices_shown = 0
198
+ for i in range(n_rows_voices):
199
+ voice_cols_list = st.columns(n_cols_voices)
200
+ for j in range(n_cols_avatars):
201
+ voice_cols_list[j].audio(
202
+ os.path.join('voice_audios', voice_audio_list[j+i*3])
203
+ )
204
+ voice_cols_list[j].write(voice_names_list[j+i*3])
205
+ n_voices_shown += 1
206
+ if n_voices_shown == len(voice_audio_list):
207
+ break
208
+
209
+ def voice_callback():
210
+ if st.session_state.voice_audio:
211
+ st.session_state.selected_voice = st.session_state.voice_audio
212
+
213
+ def uploaded_voice_callback():
214
+ if st.session_state.uploaded_voice_audio is None:
215
+ pass
216
+ else:
217
+ audio_path = "uploaded_voice_audio" + \
218
+ os.path.splitext(st.session_state.uploaded_voice_audio.name)[-1]
219
+ with open(audio_path, "wb") as f:
220
+ f.write(st.session_state.uploaded_voice_audio.getvalue())
221
+
222
+ step21_col1, step21_col2 = st.columns(2)
223
+
224
+ with step21_col1:
225
+ selected_voice = st.selectbox(
226
+ label='Please select a voice to clone',
227
+ options=voice_names_list,
228
+ key='voice_audio',
229
+ on_change=voice_callback
230
+ )
231
+
232
+ st.write('or')
233
+
234
+ uploaded_voice = st.file_uploader(
235
+ "Upload a voice to clone",
236
+ type=['mp3', 'wav'],
237
+ key='uploaded_voice_audio',
238
+ on_change=uploaded_voice_callback
239
+ )
240
+
241
+ with step21_col2:
242
+ st.markdown('<br>', unsafe_allow_html=True)
243
+ if uploaded_voice is None:
244
+ st.audio(os.path.join('voice_audios', voice_audio_list[voice_names_list.index(selected_voice)]))
245
+ else:
246
+ uploaded_voice_audio_path = glob('uploaded_voice_audio.*')[0]
247
+ st.audio(uploaded_voice_audio_path)
248
+
249
+ step21txt_col1, step21txt_col2 = st.columns(2)
250
+
251
+ with step21txt_col1:
252
+ uploaded_txt = st.text_area(
253
+ label='Please input text for avatar',
254
+ key='txt4audio'
255
+ )
256
+
257
+ def generate_audio():
258
+ if st.session_state.audio_button:
259
+
260
+ if uploaded_voice is None:
261
+ speaker_wav = os.path.join('voice_audios', voice_audio_list[voice_names_list.index(selected_voice)])
262
+ else:
263
+ speaker_wav = "uploaded_voice_audio.mp3"
264
+
265
+ get_audio(
266
+ synthesizer, speaker, language,
267
+ speaker_wav, split_sentences,
268
+ text=st.session_state.txt4audio
269
+ )
270
+
271
+ with step21txt_col2:
272
+ st.markdown('<br>', unsafe_allow_html=True)
273
+ st.button(
274
+ label='Generate audio from text',
275
+ key='audio_button',
276
+ on_click=generate_audio
277
+ )
278
+
279
+ if st.session_state.audio_button:
280
+ gen_audio_col1, _ = st.columns(2)
281
+ gen_audio_col1.audio("generated_audio.wav")
282
+
283
+ # st.subheader('Step 2 - Option 2')
284
+
285
+ option1_expander = st.expander('Option 2')
286
+ option1_expander.write(
287
+ '''Please, just upload an audio that will be reproduced in the video.
288
+ '''
289
+ )
290
+
291
+ def uploaded_audio_callback():
292
+ if st.session_state.uploaded_audio is None:
293
+ pass
294
+ else:
295
+ audio_path = "uploaded_audio" + \
296
+ os.path.splitext(st.session_state.uploaded_audio.name)[-1]
297
+ with open(audio_path, "wb") as f:
298
+ f.write(st.session_state.uploaded_audio.getvalue())
299
+
300
+ step22_col1, step22_col2 = st.columns(2)
301
+
302
+ with step22_col1:
303
+ uploaded_audio = st.file_uploader(
304
+ "Please, upload an audio",
305
+ type=['mp3', 'wav'],
306
+ key='uploaded_audio',
307
+ on_change=uploaded_audio_callback
308
+ )
309
+
310
+ with step22_col2:
311
+ st.markdown('<br>', unsafe_allow_html=True)
312
+ if uploaded_audio is None:
313
+ pass
314
+ else:
315
+ st.audio(glob('uploaded_audio.*')[0])
316
+
317
+ st.subheader('Step 3')
318
+
319
+ def generate_video():
320
+ if st.session_state.video_button:
321
+
322
+ if uploaded_audio is None:
323
+ voice_audio = glob('generated_audio.*')[0]
324
+ else:
325
+ voice_audio = glob('uploaded_audio.*')[0]
326
+
327
+ # if st.session_state.audio_button:
328
+ # voice_audio = glob('generated_audio.*')[0]
329
+ # else:
330
+ # voice_audio = os.path.join('voice_audios', voice_audio_list[voice_names_list.index(selected_voice)])
331
+
332
+ if uploaded_image is not None:
333
+ face = glob('uploaded_avatar_image.*')[0]
334
+ elif len(glob('generated_avatar.*')) != 0:
335
+ face = glob('generated_avatar.*')[0]
336
+ else:
337
+ face = os.path.join('avatar_images', avatar_images_list[avatar_names_list.index(selected_avatar)])
338
+
339
+
340
+ create_video(voice_audio, face)
341
+
342
+ step3_button_col1, _, _ = st.columns([3, 4, 5])
343
+
344
+ with step3_button_col1:
345
+ st.button(
346
+ label='Generate video',
347
+ key='video_button',
348
+ on_click=generate_video
349
+ )
350
+
351
+ if st.session_state.video_button:
352
+
353
+ step3_col1, _, _ = st.columns([4, 3, 5])
354
+
355
+ with step3_col1:
356
+ st.video(
357
+ # os.path.join('avatar_videos', 'generated_video.mp4')
358
+ 'generated_video.mp4'
359
+ )
360
+
361
+ # with step3_col2:
362
+ # # st.markdown('<br>', unsafe_allow_html=True)
363
+ # # with open(os.path.join('avatar_videos', 'generated_video.mp4'), 'rb') as file:
364
+ # with open('generated_video.mp4', 'rb') as file:
365
+ # st.download_button(
366
+ # label='Download generated video',
367
+ # data=file,
368
+ # file_name='avatar_video.mp4',
369
+ # mime='video/mp4'
370
+ # )