Update src/gradio_demo.py
src/gradio_demo.py  (+170 -169)
@@ -1,170 +1,171 @@
import spaces
import torch, uuid
import os, sys, shutil, platform
from src.facerender.pirender_animate import AnimateFromCoeff_PIRender
from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data

from src.utils.init_path import init_path

from pydub import AudioSegment

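# Helper: convert mp3 audio to wav at the requested frame rate (pydub decodes via ffmpeg).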
def mp3_to_wav(mp3_filename, wav_filename, frame_rate):
    mp3_file = AudioSegment.from_file(file=mp3_filename)
    mp3_file.set_frame_rate(frame_rate).export(wav_filename, format="wav")

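# Wraps the full SadTalker pipeline for the Gradio demo: image preprocessing,
# audio-to-coefficient prediction, and face rendering.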
class SadTalker():

    def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy_load=False):

        if torch.cuda.is_available():
            device = "cuda"
        elif platform.system() == 'Darwin':  # macos
            device = "mps"
        else:
            device = "cpu"

        self.device = device

        os.environ['TORCH_HOME'] = checkpoint_path

        self.checkpoint_path = checkpoint_path
        self.config_path = config_path

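    # @spaces.GPU requests a ZeroGPU allocation for the duration of each call.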
    @spaces.GPU
    def test(self, source_image, driven_audio, preprocess='crop',
             still_mode=False, use_enhancer=False, batch_size=1, size=256,
             pose_style=0,
             facerender='facevid2vid',
             exp_scale=1.0,
             use_ref_video=False,
             ref_video=None,
             ref_info=None,
             use_idle_mode=False,
             length_of_audio=0, use_blink=True,
             result_dir='./results/'):

        self.sadtalker_paths = init_path(self.checkpoint_path, self.config_path, size, False, preprocess)
        print(self.sadtalker_paths)

        self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device)
        self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device)

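        # facevid2vid is the default renderer; PIRender is used when requested or on MPS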
        if facerender == 'facevid2vid' and self.device != 'mps':
            self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device)
        elif facerender == 'pirender' or self.device == 'mps':
            self.animate_from_coeff = AnimateFromCoeff_PIRender(self.sadtalker_paths, self.device)
            facerender = 'pirender'
        else:
            raise RuntimeError('Unknown model: {}'.format(facerender))

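        # each request writes into its own UUID-named results folder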
        time_tag = str(uuid.uuid4())
        save_dir = os.path.join(result_dir, time_tag)
        os.makedirs(save_dir, exist_ok=True)

        input_dir = os.path.join(save_dir, 'input')
        os.makedirs(input_dir, exist_ok=True)

        print(source_image)
        pic_path = os.path.join(input_dir, os.path.basename(source_image))
        shutil.move(source_image, input_dir)

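        # resolve the driving audio: an uploaded file, generated silence in idle
        # mode, or (below) the audio track of the reference video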
        if driven_audio is not None and os.path.isfile(driven_audio):
            audio_path = os.path.join(input_dir, os.path.basename(driven_audio))

            # mp3 to wav
            if '.mp3' in audio_path:
                mp3_to_wav(driven_audio, audio_path.replace('.mp3', '.wav'), 16000)
                audio_path = audio_path.replace('.mp3', '.wav')
            else:
                shutil.move(driven_audio, input_dir)

        elif use_idle_mode:
            audio_path = os.path.join(input_dir, 'idlemode_'+str(length_of_audio)+'.wav')  # generate silent audio at this new audio_path
            one_sec_segment = AudioSegment.silent(duration=1000*length_of_audio)  # duration in milliseconds
            one_sec_segment.export(audio_path, format="wav")
        else:
            print(use_ref_video, ref_info)
            assert use_ref_video and ref_info == 'all'

        if use_ref_video and ref_info == 'all':  # full ref mode
            ref_video_videoname = os.path.basename(ref_video)
            audio_path = os.path.join(save_dir, ref_video_videoname+'.wav')
            print('new audiopath:', audio_path)
            # if ref_video contains audio, set the audio from ref_video
            cmd = r"ffmpeg -y -hide_banner -loglevel error -i %s %s" % (ref_video, audio_path)
            os.system(cmd)

        os.makedirs(save_dir, exist_ok=True)

        # crop the image and extract 3DMM coefficients from it
        first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
        os.makedirs(first_frame_dir, exist_ok=True)
        first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess, True, size)

        if first_coeff_path is None:
            raise AttributeError("No face is detected")

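        # optionally extract 3DMM coefficients from the reference video as well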
        if use_ref_video:
            print('using ref video for generation')
            ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0]
            ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname)
            os.makedirs(ref_video_frame_dir, exist_ok=True)
            print('3DMM Extraction for the reference video providing pose')
            ref_video_coeff_path, _, _ = self.preprocess_model.generate(ref_video, ref_video_frame_dir, preprocess, source_image_flag=False)
        else:
            ref_video_coeff_path = None

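        # decide which coefficient streams (pose, blink) come from the reference video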
        if use_ref_video:
            if ref_info == 'pose':
                ref_pose_coeff_path = ref_video_coeff_path
                ref_eyeblink_coeff_path = None
            elif ref_info == 'blink':
                ref_pose_coeff_path = None
                ref_eyeblink_coeff_path = ref_video_coeff_path
            elif ref_info == 'pose+blink':
                ref_pose_coeff_path = ref_video_coeff_path
                ref_eyeblink_coeff_path = ref_video_coeff_path
            elif ref_info == 'all':
                ref_pose_coeff_path = None
                ref_eyeblink_coeff_path = None
            else:
                raise ValueError('error in ref_info')
        else:
            ref_pose_coeff_path = None
            ref_eyeblink_coeff_path = None

        # audio2coeff
        if use_ref_video and ref_info == 'all':
            coeff_path = ref_video_coeff_path  # self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
        else:
            batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path, still=still_mode,
                             idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink)  # longer audio?
            coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)

        # coeff2video
        data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode,
                                   preprocess=preprocess, size=size, expression_scale=exp_scale, facemodel=facerender)
        return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size)
        video_name = data['video_name']
        print(f'The generated video is named {video_name} in {save_dir}')

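        # drop the models and flush GPU memory so the allocation is released between calls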
        del self.preprocess_model
        del self.audio_to_coeff
        del self.animate_from_coeff

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        import gc; gc.collect()

        return return_path
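
For context, a minimal sketch of how this class might be wired into a Gradio app. This is an illustration, not the Space's actual app.py: the component labels and two-input signature are assumptions, relying on test()'s defaults for the remaining arguments.

# hypothetical usage sketch, assuming the default checkpoints/ and src/config layout
import gradio as gr
from src.gradio_demo import SadTalker

sad_talker = SadTalker(checkpoint_path='checkpoints', config_path='src/config')

demo = gr.Interface(
    fn=sad_talker.test,
    inputs=[gr.Image(type='filepath', label='Source image'),
            gr.Audio(type='filepath', label='Driving audio')],
    outputs=gr.Video(label='Generated video'),
)
demo.launch()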