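# File-viewer header captured when this file was copied (kept as a comment so the module parses):
#   last commit eae1cca "update" by hytian2@gmail.com | raw | history blame | 12.8 kB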
import os
# set CUDA_MODULE_LOADING=LAZY to speed up the serverless function
os.environ["CUDA_MODULE_LOADING"] = "LAZY"
# set SAFETENSORS_FAST_GPU=1 to speed up the serverless function
os.environ["SAFETENSORS_FAST_GPU"] = "1"
import cv2
import torch
import time
import imageio
import numpy as np
from tqdm import tqdm
import moviepy.editor as mp
import torch
from audio import load_wav, melspectrogram
from fete_model import FETE_model
from preprocess_videos import face_detect, load_from_npz
# Frame rate of the generated video; also used to map video frames to mel frames.
fps = 25
# Mel columns advanced per video frame. NOTE(review): assumes melspectrogram()
# yields 80 mel frames per second of audio -- confirm against audio.py's hop size.
mel_idx_multiplier = 80.0 / fps
# Width (in mel columns) of each audio/attribute window fed to the model.
mel_step_size = 16

_has_cuda = torch.cuda.is_available()
batch_size = 64 if _has_cuda else 4
device = "cuda" if _has_cuda else "cpu"
print("Using {} for inference.".format(device))
# Half precision is only enabled when running on the GPU.
use_fp16 = _has_cuda
if use_fp16:
    print("Using FP16 for inference.")
torch.backends.cudnn.benchmark = device == "cuda"
def init_model(checkpoint_path=None):
    """Build ``FETE_model``, load its weights and prepare it for inference.

    Args:
        checkpoint_path: a ``.pth``/``.ckpt`` torch checkpoint (either a dict with
            a ``"state_dict"`` entry or a bare state dict) or a safetensors file.
            Defaults to ``checkpoints/obama-fp16.safetensors`` next to this file.

    Returns:
        The model moved to ``device`` and switched to eval mode. When
        ``use_fp16`` is set, every module is in half precision except those whose
        name contains ``.query_conv``, ``.key_conv`` or ``.value_conv``, which stay
        in float32.
    """
    if checkpoint_path is None:
        checkpoint_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "checkpoints/obama-fp16.safetensors"
        )
    model = FETE_model()
    if checkpoint_path.endswith((".pth", ".ckpt")):
        # On CPU-only hosts remap tensors that were saved from the GPU onto the CPU.
        map_location = None if device == "cuda" else "cpu"
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        # Accept training checkpoints ({"state_dict": ...}) as well as bare state dicts.
        state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    else:
        from safetensors import safe_open

        with safe_open(checkpoint_path, framework="pt", device=device) as f:
            state_dict = {key: f.get_tensor(key) for key in f.keys()}
    # Strip the "module." prefix (presumably added by DataParallel/DDP when saving).
    # NOTE(review): str.replace removes "module." anywhere in a key, not only as a
    # leading prefix; kept as-is because existing checkpoints load with it.
    model.load_state_dict({k.replace("module.", ""): v for k, v in state_dict.items()})
    model = model.to(device)
    model.eval()
    print("Model loaded")
    if use_fp16:
        # named_modules() yields parents before children and Module.to() recurses,
        # so each module's own parameters end up in the dtype chosen on its visit:
        # float32 for the attention projections (half precision errors there),
        # float16 for everything else.
        for name, module in model.named_modules():
            if ".query_conv" in name or ".key_conv" in name or ".value_conv" in name:
                module.to(torch.float)
            else:
                module.to(torch.half)
        print("Model converted to half precision to accelerate inference")
    return model
def make_mask(image_size=256, border_size=32):
    """Build the ``(image_size, image_size, 3)`` float32 alpha mask used for blending.

    * Left/right bands (``border_size`` columns): ramp linearly from 1 at the outer
      edge down to 0 inward.
    * Bottom band (``border_size`` rows): outside the side bands it ramps from 0 at
      the band's first row up to 1 at the last row. Inside the side bands it keeps
      their values. The whole band is then floored at 0.6.
    * Everything else is 0.
    """
    ramp = np.linspace(1, 0, border_size)
    alpha = np.zeros((image_size, image_size), dtype=np.float32)
    # Bottom band: row k of the band receives ramp[border_size - 1 - k].
    alpha[-border_size:, :] += ramp[::-1].reshape(-1, 1)
    # Side bands overwrite the bottom band where they overlap it.
    alpha[:, :border_size] = ramp
    alpha[:, -border_size:] = ramp[::-1]
    bottom_band = alpha[-border_size:, :]
    np.maximum(bottom_band, 0.6, out=bottom_band)
    return np.repeat(alpha[:, :, np.newaxis], 3, axis=2)


# Default 256x256 mask shared by blend_images (resized there to each face crop).
face_mask = make_mask(image_size=256, border_size=32)
def blend_images(foreground, background):
    """Alpha-blend two equally sized images using the module-level ``face_mask``.

    ``face_mask`` is resized to the images' height and width. The result is
    ``foreground * mask + background * (1 - mask)``, clipped to [0, 255] and
    returned as uint8. Where the mask is 1 the foreground wins; where it is 0 the
    background wins.
    """
    height, width = foreground.shape[:2]
    alpha = cv2.resize(face_mask, (width, height))
    fg_part = cv2.multiply(foreground.astype(np.float32), alpha)
    bg_part = cv2.multiply(background.astype(np.float32), 1 - alpha)
    return np.clip(fg_part + bg_part, 0, 255).astype(np.uint8)
def smooth_coord(last_coord, current_coord, factor=0.4):
    """Move ``last_coord`` a fraction ``factor`` of the way toward ``current_coord``.

    Works element-wise on equal-length coordinate sequences. The results are
    truncated toward zero to ints and returned as a plain list.
    """
    previous = np.array(last_coord)
    step = (np.array(current_coord) - previous) * factor
    return (previous + step).astype(int).tolist()
def add_black(imgs, top=100, bottom=20):
    """Pad every image in ``imgs`` with black (zero) rows above and below, in place.

    Each entry is replaced by a new array with ``top`` zero rows prepended and
    ``bottom`` zero rows appended. Width, channel count and dtype are unchanged,
    and images of any dtype or channel count are accepted. ``remove_black``
    strips the default 100/20 padding again.

    Returns:
        The same list object, for convenience.
    """
    for i, img in enumerate(imgs):
        # Pad only the row axis; every other axis keeps its extent.
        pad_width = [(top, bottom)] + [(0, 0)] * (img.ndim - 1)
        imgs[i] = np.pad(img, pad_width, mode="constant", constant_values=0)
    return imgs
def remove_black(img, top=100, bottom=20):
    """Drop ``top`` leading and ``bottom`` trailing rows (the ``add_black`` padding).

    Uses an explicit stop index so ``bottom=0`` keeps the last rows. A negative
    slice such as ``img[100:-0]`` would instead return nothing.
    """
    return img[top : max(len(img) - bottom, 0)]
def resize_length(input_attributes, length):
    """Nearest-neighbour resample ``input_attributes`` along axis 0 to ``length`` steps.

    Output step ``k`` takes source row ``int(k * n / length)``, where ``n`` is the
    number of source rows. The stacked result is then transposed, so an
    ``(n, C)`` input becomes ``(C, length)``.
    """
    source = np.array(input_attributes)
    n_rows = source.shape[0]
    picked = []
    for k in range(length):
        picked.append(source[int(k * (n_rows / length))])
    return np.array(picked).T
def output_chunks(input_attributes, idx_multiplier=None, step_size=None):
    """Cut a 2-D array into fixed-width windows along axis 1 (time).

    Window ``i`` starts at column ``int(i * idx_multiplier)`` and spans
    ``step_size`` columns. The first window that would run past the last column
    is replaced by the final ``step_size`` columns, and it ends the list.

    Args:
        input_attributes: 2-D array indexed as ``[channel, time]``.
        idx_multiplier: column stride between consecutive window starts;
            defaults to the module-level ``mel_idx_multiplier``.
        step_size: window width in columns; defaults to ``mel_step_size``.

    Returns:
        List of slices (views) of ``input_attributes``.

    Raises:
        ValueError: if ``idx_multiplier`` is not positive (the loop would never end).
    """
    if idx_multiplier is None:
        idx_multiplier = mel_idx_multiplier
    if step_size is None:
        step_size = mel_step_size
    if idx_multiplier <= 0:
        raise ValueError("idx_multiplier must be positive, got {}".format(idx_multiplier))
    chunks = []
    total = len(input_attributes[0])
    i = 0
    while True:
        start = int(i * idx_multiplier)
        if start + step_size > total:
            # NOTE(review): when total < step_size this tail window is narrower
            # than step_size.
            chunks.append(input_attributes[:, total - step_size :])
            return chunks
        chunks.append(input_attributes[:, start : start + step_size])
        i += 1
def prepare_data(face_path, audio_path, pose, emotion, blink, img_size=256, pads=None):
    """Load the face source and the driving audio and build the batch generator.

    Args:
        face_path: a still image (``.jpg``/``.jpeg``/``.png``, any case) or a video.
        audio_path: driving audio, passed to ``load_wav`` with 16000
            (presumably the sample rate).
        pose, emotion, blink: per-step attribute sequences. Each is resampled
            with ``resize_length`` to one value per mel frame.
        img_size: side length of the face crops fed to the model.
        pads: face-detection padding forwarded to ``datagen``; defaults to
            ``[0, 0, 0, 0]``.

    Returns:
        ``(gen, steps)``: the ``datagen`` generator and the number of batches it
        yields.

    Raises:
        ValueError: if no frames can be read from ``face_path``, or the mel
            spectrogram contains NaN.
    """
    if pads is None:
        pads = [0, 0, 0, 0]
    # splitext copes with "./dir/face.png" and multi-dot names, which
    # face_path.split(".")[1] misclassified (or crashed on when there was no dot).
    extension = os.path.splitext(face_path)[1].lower()
    static = os.path.isfile(face_path) and extension in (".jpg", ".png", ".jpeg")
    if static:
        image = cv2.imread(face_path)  # returns None when the file cannot be decoded
        full_frames = [] if image is None else [image]
    else:
        full_frames = []
        video_stream = cv2.VideoCapture(face_path)
        try:
            while True:
                still_reading, frame = video_stream.read()
                if not still_reading:
                    break
                full_frames.append(frame)
        finally:
            video_stream.release()
    print("Number of frames available for inference: " + str(len(full_frames)))
    if not full_frames:
        raise ValueError("Could not read any frames from {}".format(face_path))

    wav = load_wav(audio_path, 16000)
    mel = melspectrogram(wav)
    if np.isnan(mel).any():
        raise ValueError("Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again")
    mel_len = mel.shape[1]
    pose = resize_length(pose, mel_len)
    emotion = resize_length(emotion, mel_len)
    blink = resize_length(blink, mel_len)

    mel_chunks = output_chunks(mel)
    pose_chunks = output_chunks(pose)
    emotion_chunks = output_chunks(emotion)
    blink_chunks = output_chunks(blink)
    gen = datagen(
        face_path,
        full_frames,
        mel_chunks,
        pose_chunks,
        emotion_chunks,
        blink_chunks,
        static=static,
        img_size=img_size,
        pads=pads,
    )
    steps = int(np.ceil(float(len(mel_chunks)) / batch_size))
    return gen, steps
def preprocess_batch(batch):
    """Stack a list of same-shaped 2-D arrays into a float32 ``(N, 1, H, W)`` tensor.

    ``H`` and ``W`` come from the first element's shape (e.g. the ``(80, 16)`` mel
    windows). The tensor is moved to the module-level ``device``.
    """
    first = batch[0]
    target_shape = (len(batch), 1, first.shape[0], first.shape[1])
    stacked = np.reshape(batch, target_shape)
    return torch.FloatTensor(stacked).to(device)
def _crop_faces(frames, coords):
    """Pair each frame with its face crop; ``coords`` rows are ``(x1, y1, x2, y2)``."""
    return [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(frames, coords)]


def _detect_faces(face_path, frames, pads, static):
    """Return ``[face_crop, (y1, y2, x1, x2)]`` pairs, preferring cached coordinates.

    Coordinates are first looked up with ``load_from_npz``, keyed by the file name
    up to its first dot. On any failure ``face_detect`` runs instead, on the first
    frame only when ``static``.
    """
    try:
        video_name = os.path.basename(face_path).split(".")[0]
        return _crop_faces(frames, load_from_npz(video_name))
    except Exception as e:  # best effort: any cache problem falls back to detection
        print("No existing coords found, running face detection...", "Error: ", e)
    coords = face_detect([frames[0]] if static else frames, pads)
    return _crop_faces(frames, coords)


def _collate_batch(img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, scale_factor):
    """Convert accumulated per-frame lists into the tuple of model input tensors."""
    img_masked = np.asarray(img_batch).copy()
    border = 16 * scale_factor
    # Zero the interior of the masked copy so only a `border`-pixel frame remains
    # (when scale_factor is 0 the slice is empty and nothing is masked).
    img_masked[:, border:-border, border:-border] = 0.0
    imgs = np.concatenate((img_masked, img_batch), axis=3) / 255.0
    tensors = (
        torch.FloatTensor(np.transpose(imgs, (0, 3, 1, 2))).to(device),
        preprocess_batch(mel_batch),
        preprocess_batch(pose_batch),
        preprocess_batch(emotion_batch),
        preprocess_batch(blink_batch),
    )
    if use_fp16:
        tensors = tuple(t.half() for t in tensors)
    return tensors


def datagen(face_path, frames, mels, poses, emotions, blinks, static=False, img_size=256, pads=None):
    """Yield model-input batches together with the frames they will be pasted into.

    Each item is ``((img, mel, pose, emotion, blink), frame_batch, coords_batch)``
    with at most ``batch_size`` entries. ``img`` is an
    ``(N, 2*C, img_size, img_size)`` tensor: a border-only masked copy of each
    face crop stacked with the full crop along channels, divided by 255.
    ``coords_batch`` holds ``(y1, y2, x1, x2)`` per frame.

    Frames beyond ``len(mels)`` are dropped. Shorter footage is played forward,
    then backward, repeatedly until it covers every mel chunk. With ``static``,
    frame 0 is reused for every chunk. Frames are padded with ``add_black`` before
    face cropping.

    Raises:
        ValueError: if ``frames`` is empty or no face coordinates are obtained
            (either would otherwise loop forever while extending the footage).
    """
    if pads is None:
        pads = [0, 0, 0, 0]
    if len(mels) == 0:
        return
    if len(frames) == 0:
        raise ValueError("No frames available for {}".format(face_path))
    scale_factor = img_size // 128
    frames = add_black(frames[: len(mels)])
    face_det_results = _detect_faces(face_path, frames, pads, static)
    if not face_det_results:
        raise ValueError("No face coordinates found for {}".format(face_path))
    # Keep frames index-aligned with the face results (cached coords may cover fewer frames).
    frames = frames[: len(face_det_results)]
    while len(frames) < len(mels):
        face_det_results = face_det_results + face_det_results[::-1]
        frames = frames + frames[::-1]
    face_det_results = face_det_results[: len(mels)]
    frames = frames[: len(mels)]

    img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, frame_batch, coords_batch = [], [], [], [], [], [], []
    for i in range(len(mels)):
        idx = 0 if static else i % len(frames)
        face, face_coords = face_det_results[idx]
        img_batch.append(cv2.resize(face, (img_size, img_size)))
        mel_batch.append(mels[i])
        pose_batch.append(poses[i])
        emotion_batch.append(emotions[i])
        blink_batch.append(blinks[i])
        # Copy: infenrece pastes predictions into these frames in place, and the
        # extended list may hold the same frame object several times.
        frame_batch.append(frames[idx].copy())
        coords_batch.append(face_coords)
        if len(img_batch) >= batch_size:
            inputs = _collate_batch(img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, scale_factor)
            yield inputs, frame_batch, coords_batch
            img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, frame_batch, coords_batch = [], [], [], [], [], [], []
    if img_batch:
        inputs = _collate_batch(img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, scale_factor)
        yield inputs, frame_batch, coords_batch
def _paste_prediction(pred_face, frame, coord):
    """Resize one predicted face to its box and blend it into ``frame`` in place.

    ``coord`` is ``(y1, y2, x1, x2)``. If blending fails, the resized prediction
    is pasted directly.
    """
    y1, y2, x1, x2 = (int(v) for v in coord)
    height, width = y2 - y1, x2 - x1
    # Clip before the uint8 cast so out-of-range outputs saturate instead of wrapping.
    face = cv2.resize(np.clip(pred_face, 0, 255).astype(np.uint8), (width, height))
    region = (slice(y1, y1 + height), slice(x1, x1 + width))
    try:
        frame[region] = blend_images(frame[region], face)
    except Exception as e:
        print(e)
        frame[region] = face


def _mux_audio(video_path, audio_path, outfile):
    """Write ``outfile``: the video at ``video_path`` with ``audio_path`` as its soundtrack."""
    video_clip = mp.VideoFileClip(video_path)
    try:
        audio_clip = mp.AudioFileClip(audio_path)
        try:
            video_clip.set_audio(audio_clip).write_videofile(outfile, codec="libx264")
        finally:
            audio_clip.close()  # release the ffmpeg reader
    finally:
        video_clip.close()


def infenrece(model, face_path, audio_path, pose, emotion, blink, preview=False):
    """Render the talking-face result for ``audio_path`` and return its path.

    Args:
        model: the network returned by ``init_model``.
        face_path, audio_path, pose, emotion, blink: forwarded to ``prepare_data``.
        preview: if true, write only the first generated frame as a JPEG.

    Returns:
        Path of the written ``.jpg`` (preview) or ``.mp4`` under ``/tmp``.

    Raises:
        RuntimeError: in preview mode, if no frame was generated.
    """
    import uuid

    timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime(time.time()))
    # A random suffix keeps concurrent calls within the same second from
    # overwriting each other's /tmp files.
    run_id = "{}-{}".format(timestamp, uuid.uuid4().hex[:8])
    gen, steps = prepare_data(face_path, audio_path, pose, emotion, blink)
    steps = 1 if preview else steps
    outfile = "/tmp/{}.{}".format(run_id, "jpg" if preview else "mp4")
    tmp_video = "/tmp/temp_{}.mp4".format(run_id)

    writer = None
    if not preview:
        writer = imageio.get_writer(tmp_video, fps=fps, codec="libx264", quality=10, pixelformat="yuv420p", macro_block_size=1)
    try:
        for inputs, frames, coords in tqdm(gen, total=steps):
            with torch.no_grad():
                pred = model(*inputs)
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.0
            for p, f, c in zip(pred, frames, coords):
                _paste_prediction(p, f, c)
                f = remove_black(f)
                if preview:
                    cv2.imwrite(outfile, f, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
                    return outfile
                writer.append_data(cv2.cvtColor(f, cv2.COLOR_BGR2RGB))
    finally:
        if writer is not None:
            writer.close()
    if preview:
        raise RuntimeError("No frames were generated for {}".format(face_path))

    try:
        _mux_audio(tmp_video, audio_path, outfile)
    finally:
        try:
            os.remove(tmp_video)
        except OSError as e:
            print("Could not remove temporary file {}: {}".format(tmp_video, e))
    print("Saved to {}".format(outfile) if os.path.exists(outfile) else "Failed to save {}".format(outfile))
    return outfile
if __name__ == "__main__":
    # Demo: render the bundled sample video driven by the bundled sample audio.
    model = init_model()
    from attributtes_utils import input_pose, input_emotion, input_blink

    pose, emotion, blink = input_pose(), input_emotion(), input_blink()
    audio_path, face_path = "./assets/sample.wav", "./assets/sample.mp4"
    infenrece(model, face_path=face_path, audio_path=audio_path, pose=pose, emotion=emotion, blink=blink)