diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..3774d0746a200107fe429239b59afca6cf67d844 Binary files /dev/null and b/.DS_Store differ diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..623966df1a8954fe4887ba1f343c13894d6a4332 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +Asian_Women_correct.png filter=lfs diff=lfs merge=lfs -text +avatar.png filter=lfs diff=lfs merge=lfs -text +driver_video.mp4 filter=lfs diff=lfs merge=lfs -text +tortoise/voices/train_lescault/lescault_new4.wav filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..983f46d5a7f13fb1dc59088df45ef7936fbe319d --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +*.pth +*.avi +*.png +__pycache__ +__pycache__/* +.DS_Store +temp/* +results/* +*.lock +!assets/*.mp4 +!tortoise/data/mel_norms.pth \ No newline at end of file diff --git a/Asian_Women_correct.png b/Asian_Women_correct.png new file mode 100644 index 0000000000000000000000000000000000000000..b8f25a57b98a1772b02cdf43fc848e5cc932bd37 --- /dev/null +++ b/Asian_Women_correct.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b07f9f345bf4e2064b0582f8d5caad2de7aec3e112b186fe59c29032f2a01ce6 +size 1574757 diff --git a/__pycache__/animate_face.cpython-310.pyc b/__pycache__/animate_face.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f730365882eb99ecf420cd174987e6a6567964a2 Binary files /dev/null and b/__pycache__/animate_face.cpython-310.pyc differ diff --git a/__pycache__/config.cpython-310.pyc b/__pycache__/config.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d8ad143f779b710721548ad7fb7230a35245acbe Binary files /dev/null and b/__pycache__/config.cpython-310.pyc differ diff --git a/__pycache__/config.cpython-38.pyc b/__pycache__/config.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9d940adb4ad1a0a44b164df5c78c106ea94e65e Binary files /dev/null and b/__pycache__/config.cpython-38.pyc differ diff --git a/__pycache__/image.cpython-310.pyc b/__pycache__/image.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a1a7564318a1782b288f4b35667a070367afbfe Binary files /dev/null and b/__pycache__/image.cpython-310.pyc differ diff --git a/__pycache__/improve.cpython-310.pyc b/__pycache__/improve.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc847682423f7f3224c7ec7e10757b64b47225a6 Binary files /dev/null and b/__pycache__/improve.cpython-310.pyc differ diff --git a/__pycache__/lips.cpython-310.pyc b/__pycache__/lips.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89a23d717a77c40ba30f695bd7453f0a1acaa463 Binary files /dev/null and b/__pycache__/lips.cpython-310.pyc differ diff --git a/__pycache__/speech.cpython-310.pyc b/__pycache__/speech.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c24cfad555521461d328416086fa9317c62bdcd Binary files /dev/null and b/__pycache__/speech.cpython-310.pyc differ diff --git a/animate_face.py b/animate_face.py new file mode 100644 index 0000000000000000000000000000000000000000..0e342fb31369ce1b0786ace57dcccbafdf417415 --- /dev/null +++ b/animate_face.py @@ -0,0 +1,390 @@ +import os +import sys +import cv2 +import yaml +import imageio +import numpy as np +import torch +import torch.nn.functional as F +import subprocess, platform +from mutagen.wave import WAVE +from datetime import timedelta + +from face_vid2vid.sync_batchnorm.replicate import DataParallelWithCallback +from 
def get_rotation_matrix(yaw, pitch, roll):
    """Compose batched 3x3 rotation matrices from Euler angles in degrees.

    Args:
        yaw, pitch, roll: 1-D tensors of equal length B holding angles in
            degrees.

    Returns:
        Tensor of shape (B, 3, 3); per batch element it is the product
        ``pitch_mat @ yaw_mat @ roll_mat``.
    """

    def _radians_column(angle):
        # Degrees -> radians using the literal 3.14 (not math.pi) so the
        # numerical result is unchanged; (B,) -> (B, 1) for concatenation.
        return (angle / 180 * 3.14).unsqueeze(1)

    def _as_matrix(entries):
        # Nine (B, 1) columns in row-major order -> (B, 3, 3).
        flat = torch.cat(entries, dim=1)
        return flat.view(flat.shape[0], 3, 3)

    yaw_c = _radians_column(yaw)
    pitch_c = _radians_column(pitch)
    roll_c = _radians_column(roll)

    # Rotation about the x axis.
    p_zero, p_one = torch.zeros_like(pitch_c), torch.ones_like(pitch_c)
    p_cos, p_sin = torch.cos(pitch_c), torch.sin(pitch_c)
    pitch_mat = _as_matrix([
        p_one, p_zero, p_zero,
        p_zero, p_cos, -p_sin,
        p_zero, p_sin, p_cos,
    ])

    # Rotation about the y axis.
    y_zero, y_one = torch.zeros_like(yaw_c), torch.ones_like(yaw_c)
    y_cos, y_sin = torch.cos(yaw_c), torch.sin(yaw_c)
    yaw_mat = _as_matrix([
        y_cos, y_zero, y_sin,
        y_zero, y_one, y_zero,
        -y_sin, y_zero, y_cos,
    ])

    # Rotation about the z axis.
    r_zero, r_one = torch.zeros_like(roll_c), torch.ones_like(roll_c)
    r_cos, r_sin = torch.cos(roll_c), torch.sin(roll_c)
    roll_mat = _as_matrix([
        r_cos, -r_sin, r_zero,
        r_sin, r_cos, r_zero,
        r_zero, r_zero, r_one,
    ])

    return torch.einsum("bij,bjk,bkm->bim", pitch_mat, yaw_mat, roll_mat)
def get_square_face(coords, image):
    """Crop an enlarged, square region around a face box, clamped to the image.

    Args:
        coords: ``(x1, y1, x2, y2)`` corners of the detected face box.
        image: array indexed as ``image[row, col, ...]``.

    Returns:
        The slice ``image[y1:y2, x1:x2]`` for the expanded box. It can be
        non-square when the box is clipped at an image border.
    """
    left, top, right, bottom = coords

    # Grow every side by half of (longer edge // 2): about 1.5x overall.
    margin = (max(right - left, bottom - top) // 2) * 0.5
    left, right = left - margin, right + margin
    top, bottom = top - margin, bottom + margin

    # Re-centre and take the longer edge as the side of the square.
    centre_x = (left + right) // 2
    centre_y = (top + bottom) // 2
    half_side = max(right - left, bottom - top) // 2

    height, width = image.shape[0], image.shape[1]
    col_start = max(int(round(centre_x - half_side)), 0)
    col_stop = min(int(round(centre_x + half_side)), width)
    row_start = max(int(round(centre_y - half_side)), 0)
    row_stop = min(int(round(centre_y + half_side)), height)
    return image[row_start:row_stop, col_start:col_stop]
os.path.join(os.path.expanduser("~"), ".cache/torch/hub/checkpoints/FaceMapping.pth.tar") + if not os.path.exists(checkpoint_path): + os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) + from gdown import download + file_id = "11ZgyjKI5OcB7klcsIdPpCCX38AIX8Soc" + download(id=file_id, output=checkpoint_path, quiet=False) + if use_sr: + from face_vid2vid.GPEN.face_enhancement import FaceEnhancement + + self.faceenhancer = FaceEnhancement( + size=256, model="GPEN-BFR-256", use_sr=False, sr_model="realesrnet_x2", channel_multiplier=1, narrow=0.5, use_facegan=True + ) + + # load checkpoints + self.generator, self.kp_detector, self.he_estimator = load_checkpoints(config_path=config_path, checkpoint_path=checkpoint_path) + source_image = cv2.cvtColor(cv2.imread(source_image_path), cv2.COLOR_RGB2BGR).astype(np.float32) / 255. + source_image = cv2.resize(source_image, (256, 256), interpolation=cv2.INTER_AREA) + source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2) + self.source = source.cuda() + + # initilize face detectors + self.face_detector = RetinaFace() + self.detect_interval = 8 + self.smooth_factor = 0.2 + + # load base frame and blank frame + self.base_frame = cv2.imread(source_image_path) if not use_sr else self.faceenhancer.process(cv2.imread(source_image_path))[0] + self.base_frame = cv2.resize(self.base_frame, (256, 256)) + self.blank_frame = np.ones(self.base_frame.shape, dtype=np.uint8) * 255 + cv2.putText(self.blank_frame, "Face not", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2) + cv2.putText(self.blank_frame, "detected!", (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2) + + # count for frame + self.n_frame = 0 + + # initilize variables + self.first_frame = True + self.last_coords = None + self.coords = None + self.use_sr = use_sr + self.kp_source = None + self.kp_driving_initial = None + + + def _conver_input_frame(self, frame): + frame = cv2.resize(frame, (256, 256), 
interpolation=cv2.INTER_NEAREST).astype(np.float32) / 255.0 + return torch.tensor(frame[np.newaxis]).permute(0, 3, 1, 2).cuda() + + def _process_first_frame(self, frame): + print("Processing first frame") + # function to process the first frame + faces = self.face_detector(frame, cv=True) + if len(faces) == 0: + raise ValueError("Face is not detected") + else: + self.coords = faces[0][0] + face = get_square_face(self.coords, frame) + self.last_coords = self.coords + + # get the keypoint and headpose from the source image + with torch.no_grad(): + self.kp_canonical = self.kp_detector(self.source) + self.he_source = self.he_estimator(self.source) + + face_input = self._conver_input_frame(face) + he_driving_initial = self.he_estimator(face_input) + self.kp_driving_initial, coordinates = keypoint_transformation(self.kp_canonical, he_driving_initial, output_coord=True) + self.kp_source = keypoint_transformation( + self.kp_canonical, self.he_source, free_view=True, yaw=coordinates["yaw"], pitch=coordinates["pitch"], roll=coordinates["roll"] + ) + + def _inference(self, frame): + # function to process the rest frames + with torch.no_grad(): + self.n_frame += 1 + if self.first_frame: + self._process_first_frame(frame) + self.first_frame = False + else: + pass + if self.n_frame % self.detect_interval == 0: + faces = self.face_detector(frame, cv=True) + if len(faces) == 0: + raise ValueError("Face is not detected") + else: + self.coords = faces[0][0] + self.coords = smooth_coord(self.last_coords, self.coords, self.smooth_factor) + face = get_square_face(self.coords, frame) + self.last_coords = self.coords + face_input = self._conver_input_frame(face) + + he_driving = self.he_estimator(face_input) + kp_driving = keypoint_transformation(self.kp_canonical, he_driving) + kp_norm = normalize_kp( + kp_source=self.kp_source, + kp_driving=kp_driving, + kp_driving_initial=self.kp_driving_initial, + use_relative_movement=True, + adapt_movement_scale=True, + ) + + out = 
def seconds_to_hms(seconds):
    """Format a duration as a zero-padded ``HH:MM:SS`` string.

    The duration is truncated to whole seconds and then one second is
    added, so the result is never shorter than the input duration.

    Args:
        seconds: non-negative duration in seconds (int or float).

    Returns:
        ``"HH:MM:SS"``. Hours are not wrapped at 24, so a duration of a
        day or more yields e.g. ``"24:00:01"`` rather than the
        ``"1 day, 0:00:01"`` form that ``str(timedelta)`` produces.
    """
    total = int(seconds) + 1
    hours, remainder = divmod(total, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
def main():
    """Run the avatar pipeline end to end and print per-stage timings.

    Stages: obtain the avatar image (generate it, or copy the one given
    with ``--image``), extract the driver video's audio to WAV, animate
    the avatar with the driver video, pitch-shift the audio, and mux the
    animation and audio into ``results/<path_id>_animated.mp4``.

    Intermediate files are written under ``temp/<path_id>/``.
    """

    def _elapsed(since):
        # Human-readable wall-clock time since `since` (whole seconds).
        return humanize.naturaldelta(dt.timedelta(seconds=int(time.time() - since)))

    parser = ArgumentParser()
    parser.add_argument("--image", default=imgfile, help="path to avatar file")
    parser.add_argument("--path_id", default=str(int(time.time())), help="set the path id to use")
    parser.add_argument(
        "--pitch",
        default=1.0,
        type=float,
        help="change pitch of voice, 1.0 is original, higher number is higher pitch",
    )
    args = parser.parse_args()
    tstart = time.time()

    ## SET PATH
    path_id = args.path_id
    path = os.path.join("temp", path_id)
    os.makedirs(path, exist_ok=True)
    # The final mux writes into results/; make sure it exists.
    os.makedirs("results", exist_ok=True)

    ## GENERATE AVATAR IMAGE
    timage = "None"
    if args.image == imgfile:
        print("-----------------------------------------")
        print("generating avatar image")
        t1 = time.time()
        # Implicit concatenation instead of a backslash continuation, which
        # pulls the next source line's leading whitespace into the prompt.
        prompt = (
            "hyperrealistic digital avatar, centered, "
            f"{avatar_description}, rim lighting, studio lighting, looking at the camera"
        )
        generate_image(path_id, imgfile, prompt)
        timage = _elapsed(t1)
        print("\ngenerating avatar:", timage)
    else:
        shutil.copyfile(args.image, os.path.join("temp", path_id, imgfile))

    ## EXTRACT SPEECH FROM MP4
    print("-----------------------------------------")
    print("extracting speech from mp4")
    t2 = time.time()
    wavoutfile = os.path.join(path, audiofile)
    # Argument lists (no shell) so paths containing spaces or shell
    # metacharacters are passed to ffmpeg verbatim.
    subprocess.call(
        ["ffmpeg", "-i", driverfile, "-acodec", "pcm_s16le", "-ar", "44100", "-ac", "1", wavoutfile]
    )
    tspeech = _elapsed(t2)
    print("\nextracting speech:", tspeech)

    ## ANIMATE AVATAR IMAGE
    print("-----------------------------------------")
    print("animating face with driver")
    t3 = time.time()
    # audiofile determines the length of the driver movie to trim
    # driver movie is imposed on the image file to produce the animated file
    animate_face(path_id, audiofile, driverfile, imgfile, animatedfile)
    tanimate = _elapsed(t3)
    print("\nanimating face:", tanimate)

    ## CHANGING THE PITCH OF THE VOICE
    print("-----------------------------------------")
    print("changing pitch of voice")
    t4 = time.time()
    wavpitchedfile = os.path.join(path, "pitched.wav")
    # asetrate shifts pitch and speed together; atempo=1/pitch restores
    # the original speed.
    pitch_filter = "asetrate=44100*{},aresample=44100,atempo=1/{}".format(args.pitch, args.pitch)
    subprocess.call(["ffmpeg", "-i", wavoutfile, "-af", pitch_filter, wavpitchedfile])
    tpitch = _elapsed(t4)
    print("\nchanging pitch:", tpitch)

    ## COMBINING ANIMATION WITH SPEECH
    print("-----------------------------------------")
    print("combining animation with speech")
    t5 = time.time()
    animatedoutfile = os.path.join(path, animatedfile)
    finaloutfile = os.path.join("results", path_id + "_animated.mp4")
    subprocess.call(
        [
            "ffmpeg",
            "-i", animatedoutfile,
            "-i", wavpitchedfile,
            "-c:v", "copy",
            "-map", "0:v:0",
            "-map", "1:a:0",
            "-shortest",
            finaloutfile,
        ]
    )
    tcombi = _elapsed(t5)
    print("\ncombining animation with speech:", tcombi)

    print("done")
    print("Overall timing")
    print("--------------")
    print("generating avatar image:", timage)
    print("extracting speech from mp4:", tspeech)
    print("animating face:", tanimate)
    print("changing pitch of voice:", tpitch)
    print("combining animation with speech:", tcombi)
    print(
        "total time:",
        humanize.naturaldelta(
            minimum_unit="microseconds",
            value=dt.timedelta(seconds=int(time.time() - tstart)),
        ),
    )
+static=False +fps=25 +pads=[0, 10, 0, 0] +face_det_batch_size=16 +wav2lip_batch_size=128 +resize_factor=0.5 +crop=[0, -1, 0, -1] +box=[-1, -1, -1, -1] +img_size = 96 +rotate=False +nosmooth=False +mel_step_size = 16 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +print('Using {} for inference.'.format(device)) + +import warnings +warnings.filterwarnings('ignore') + +def init_path_id(): + path_id = str(int(time.time())) + path = os.path.join("temp", path_id) + os.makedirs(path, exist_ok=True) + return path_id, path + + diff --git a/driver_video.mp4 b/driver_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..718f614220f0d063a7cbf34ec0acdc1d020f9004 --- /dev/null +++ b/driver_video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f6145b3dc938bd00a844435d9f4115fdc6d924df3d1f834dbf5608c27037ef28 +size 11159662 diff --git a/face_vid2vid/GPEN/README.md b/face_vid2vid/GPEN/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8e3ed8ffb2a0cce26f9860a5269c6752210b0386 --- /dev/null +++ b/face_vid2vid/GPEN/README.md @@ -0,0 +1,92 @@ +# GAN Prior Embedded Network for Blind Face Restoration in the Wild + +[Paper](https://arxiv.org/abs/2105.06070) | [Supplementary](https://www4.comp.polyu.edu.hk/~cslzhang/paper/GPEN-cvpr21-supp.pdf) | [Demo](https://vision.aliyun.com/experience/detail?spm=a211p3.14020179.J_7524944390.17.66cd4850wVDkUQ&tagName=facebody&children=EnhanceFace) + +[Tao Yang](https://cg.cs.tsinghua.edu.cn/people/~tyang)1, Peiran Ren1, Xuansong Xie1, [Lei Zhang](https://www4.comp.polyu.edu.hk/~cslzhang)1,2 +_1[DAMO Academy, Alibaba Group](https://damo.alibaba.com), Hangzhou, China_ +_2[Department of Computing, The Hong Kong Polytechnic University](http://www.comp.polyu.edu.hk), Hong Kong, China_ + +#### Face Restoration + + + + + + +#### Face Colorization + + + +#### Face Inpainting + + + +#### Conditional Image Synthesis (Seg2Face) + + + +## News +(2021-07-06) The training code 
will be released soon. Stay tuned. + +(2021-10-11) The Colab demo for GPEN is available now google colab logo. + +(2021-10-22) GPEN can now work with SR methods. A SR model trained by myself is provided. Replace it with your own model if necessary. + +## Usage + +![python](https://img.shields.io/badge/python-v3.7.4-green.svg?style=plastic) +![pytorch](https://img.shields.io/badge/pytorch-v1.7.0-green.svg?style=plastic) +![cuda](https://img.shields.io/badge/cuda-v10.2.89-green.svg?style=plastic) +![driver](https://img.shields.io/badge/driver-v460.73.01-green.svg?style=plastic) +![gcc](https://img.shields.io/badge/gcc-v7.5.0-green.svg?style=plastic) + +- Clone this repository: +```bash +git clone https://github.com/yangxy/GPEN.git +cd GPEN +``` +- Download RetinaFace model and our pre-trained model (not our best model due to commercial issues) and put them into ``weights/``. + + [RetinaFace-R50](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/RetinaFace-R50.pth) | [GPEN-BFR-512](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512.pth) | [GPEN-BFR-512-D](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512-D.pth) | [GPEN-BFR-256](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-256.pth) | [GPEN-Colorization-1024](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Colorization-1024.pth) | [GPEN-Inpainting-1024](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Inpainting-1024.pth) | [GPEN-Seg2face-512](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Seg2face-512.pth) | [rrdb_realesrnet_psnr](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/rrdb_realesrnet_psnr.pth) + +- Restore face images: +```bash +python face_enhancement.py --model GPEN-BFR-512 --size 512 --channel_multiplier 2 --narrow 1 --use_sr --indir examples/imgs --outdir examples/outs-BFR +``` + +- Colorize 
faces: +```bash +python face_colorization.py +``` + +- Complete faces: +```bash +python face_inpainting.py +``` + +- Synthesize faces: +```bash +python segmentation2face.py +``` + +## Main idea + + +## Citation +If our work is useful for your research, please consider citing: + + @inproceedings{Yang2021GPEN, + title={GAN Prior Embedded Network for Blind Face Restoration in the Wild}, + author={Tao Yang, Peiran Ren, Xuansong Xie, and Lei Zhang}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + year={2021} + } + +## License +© Alibaba, 2021. For academic and non-commercial use only. + +## Acknowledgments +We borrow some codes from [Pytorch_Retinaface](https://github.com/biubug6/Pytorch_Retinaface), [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch), and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN). + +## Contact +If you have any questions or suggestions about this paper, feel free to reach me at yangtao9009@gmail.com. diff --git a/face_vid2vid/GPEN/__init_paths.py b/face_vid2vid/GPEN/__init_paths.py new file mode 100644 index 0000000000000000000000000000000000000000..a959ea4e159c16e5095cc595812b3680c424d144 --- /dev/null +++ b/face_vid2vid/GPEN/__init_paths.py @@ -0,0 +1,21 @@ +''' +@paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021) +@author: yangxy (yangtao9009@gmail.com) +''' +import os.path as osp +import sys + +def add_path(path): + if path not in sys.path: + sys.path.insert(0, path) + +this_dir = osp.dirname(__file__) + +path = osp.join(this_dir, 'retinaface') +add_path(path) + +path = osp.join(this_dir, 'face_model') +add_path(path) + +path = osp.join(this_dir, 'sr_model') +add_path(path) \ No newline at end of file diff --git a/face_vid2vid/GPEN/align_faces.py b/face_vid2vid/GPEN/align_faces.py new file mode 100644 index 0000000000000000000000000000000000000000..062ffad56d05607fc10a2a8b398d03f491e1e6f5 --- /dev/null +++ b/face_vid2vid/GPEN/align_faces.py @@ -0,0 
+1,236 @@ +# -*- coding: utf-8 -*- +""" +Created on Mon Apr 24 15:43:29 2017 +@author: zhaoy +""" +""" +@Modified by yangxy (yangtao9009@gmail.com) +""" +import cv2 +import numpy as np +from skimage import transform as trans + +# reference facial points, a list of coordinates (x,y) +REFERENCE_FACIAL_POINTS = [ + [30.29459953, 51.69630051], + [65.53179932, 51.50139999], + [48.02519989, 71.73660278], + [33.54930115, 92.3655014], + [62.72990036, 92.20410156], +] + +DEFAULT_CROP_SIZE = (96, 112) + + +def _umeyama(src, dst, estimate_scale=True, scale=1.0): + """Estimate N-D similarity transformation with or without scaling. + Parameters + ---------- + src : (M, N) array + Source coordinates. + dst : (M, N) array + Destination coordinates. + estimate_scale : bool + Whether to estimate scaling factor. + Returns + ------- + T : (N + 1, N + 1) + The homogeneous similarity transformation matrix. The matrix contains + NaN values only if the problem == not well-conditioned. + References + ---------- + .. [1] "Least-squares estimation of transformation parameters between two + point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573` + """ + + num = src.shape[0] + dim = src.shape[1] + + # Compute mean of src and dst. + src_mean = src.mean(axis=0) + dst_mean = dst.mean(axis=0) + + # Subtract mean from src and dst. + src_demean = src - src_mean + dst_demean = dst - dst_mean + + # Eq. (38). + A = dst_demean.T @ src_demean / num + + # Eq. (39). + d = np.ones((dim,), dtype=np.double) + if np.linalg.det(A) < 0: + d[dim - 1] = -1 + + T = np.eye(dim + 1, dtype=np.double) + + U, S, V = np.linalg.svd(A) + + # Eq. (40) and (43). + rank = np.linalg.matrix_rank(A) + if rank == 0: + return np.nan * T + elif rank == dim - 1: + if np.linalg.det(U) * np.linalg.det(V) > 0: + T[:dim, :dim] = U @ V + else: + s = d[dim - 1] + d[dim - 1] = -1 + T[:dim, :dim] = U @ np.diag(d) @ V + d[dim - 1] = s + else: + T[:dim, :dim] = U @ np.diag(d) @ V + + if estimate_scale: + # Eq. (41) and (42). 
+ scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d) + else: + scale = scale + + T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T) + T[:dim, :dim] *= scale + + return T, scale + + +class FaceWarpException(Exception): + def __str__(self): + return "In File {}:{}".format(__file__, super.__str__(self)) + + +def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False): + tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) + tmp_crop_size = np.array(DEFAULT_CROP_SIZE) + + # 0) make the inner region a square + if default_square: + size_diff = max(tmp_crop_size) - tmp_crop_size + tmp_5pts += size_diff / 2 + tmp_crop_size += size_diff + + if output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]: + print("output_size == DEFAULT_CROP_SIZE {}: return default reference points".format(tmp_crop_size)) + return tmp_5pts + + if inner_padding_factor == 0 and outer_padding == (0, 0): + if output_size is None: + print("No paddings to do: return default reference points") + return tmp_5pts + else: + raise FaceWarpException("No paddings to do, output_size must be None or {}".format(tmp_crop_size)) + + # check output size + if not (0 <= inner_padding_factor <= 1.0): + raise FaceWarpException("Not (0 <= inner_padding_factor <= 1.0)") + + if (inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None: + output_size = tmp_crop_size * (1 + inner_padding_factor * 2).astype(np.int32) + output_size += np.array(outer_padding) + print(" deduced from paddings, output_size = ", output_size) + + if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]): + raise FaceWarpException("Not (outer_padding[0] < output_size[0]" "and outer_padding[1] < output_size[1])") + + # 1) pad the inner region according inner_padding_factor + # print('---> STEP1: pad the inner region according inner_padding_factor') + if inner_padding_factor > 0: + 
def get_affine_transform_matrix(src_pts, dst_pts):
    """Fit a 2x3 affine transform mapping ``src_pts`` onto ``dst_pts``.

    The transform is the least-squares solution over all point pairs.

    Args:
        src_pts: (K, 2) array of source points.
        dst_pts: (K, 2) array of destination points.

    Returns:
        A float32 (2, 3) matrix. When the homogeneous source points have
        rank 3 it is the full affine fit; with rank 2 the translation
        column is set to zero; otherwise the identity transform is
        returned.
    """
    tfm = np.float32([[1, 0, 0], [0, 1, 0]])
    n_pts = src_pts.shape[0]
    ones = np.ones((n_pts, 1), src_pts.dtype)
    # Homogeneous coordinates: append a column of ones to both point sets.
    src_pts_ = np.hstack([src_pts, ones])
    dst_pts_ = np.hstack([dst_pts, ones])

    # rcond=None selects NumPy's current default cutoff explicitly, which
    # avoids the FutureWarning and version-dependent behaviour of omitting it.
    A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_, rcond=None)

    if rank == 3:
        tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
    elif rank == 2:
        tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])

    return tfm
def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type="smilarity"):
    """Warp ``src_img`` so that ``facial_pts`` land on ``reference_pts`` and crop it.

    Args:
        src_img: image passed to ``cv2.warpAffine``.
        facial_pts: detected landmarks, shaped (K, 2) or (2, K) with K > 2.
        reference_pts: target landmarks in the same layout.  When ``None`` the
            module default ``REFERENCE_FACIAL_POINTS`` is used for a 96x112
            crop, otherwise ``get_reference_facial_points`` is asked for the
            points matching ``crop_size``.
        crop_size: output size, forwarded to ``cv2.warpAffine`` as
            ``(crop_size[0], crop_size[1])``.
        align_type: ``"cv2_affine"`` (exact affine from the first three
            points), ``"affine"`` (least-squares affine over all points) or
            anything else, including the default spelling ``"smilarity"``, for
            the ``_umeyama`` estimate.

    Returns:
        ``(face_img, tfm_inv)``: the warped crop and the 2x3 matrix mapping
        crop coordinates back to the source image.

    Raises:
        FaceWarpException: if either point set has an unsupported shape or the
            two sets differ in shape.
    """

    def _as_k_by_2(points, shape_error):
        # Convert to float32 and normalise a (2, K) layout to (K, 2).
        pts = np.float32(points)
        dims = pts.shape
        if max(dims) < 3 or min(dims) != 2:
            raise FaceWarpException(shape_error)
        return pts.T if dims[0] == 2 else pts

    if reference_pts is None:
        if crop_size[0] == 96 and crop_size[1] == 112:
            reference_pts = REFERENCE_FACIAL_POINTS
        else:
            # No padding, no squaring: just ask for the points of this crop size.
            reference_pts = get_reference_facial_points(crop_size, 0, (0, 0), False)

    ref_pts = _as_k_by_2(reference_pts, "reference_pts.shape must be (K,2) or (2,K) and K>2")
    src_pts = _as_k_by_2(facial_pts, "facial_pts.shape must be (K,2) or (2,K) and K>2")

    if src_pts.shape != ref_pts.shape:
        raise FaceWarpException("facial_pts and reference_pts must have the same shape")

    if align_type == "cv2_affine":
        tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
        tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3])
    elif align_type == "affine":
        tfm = get_affine_transform_matrix(src_pts, ref_pts)
        tfm_inv = get_affine_transform_matrix(ref_pts, src_pts)
    else:
        forward_params, scale = _umeyama(src_pts, ref_pts)
        # The inverse reuses the reciprocal of the forward scale.
        inverse_params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0 / scale)
        tfm = forward_params[:2, :]
        tfm_inv = inverse_params[:2, :]

    # NOTE(review): flags=3 is presumably cv2.INTER_AREA; confirm against OpenCV.
    face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3)

    return face_img, tfm_inv
def check_ckpts(model, sr_model):
    """Make sure the ``weights`` folder next to this module exists.

    If the folder is missing it is fetched from Google Drive with
    ``gdown.download_folder``; if it already exists nothing is downloaded.
    Only the presence of the folder is checked; ``model`` and ``sr_model``
    are accepted but not inspected.

    Raises:
        Exception: with the message "Error while downloading checkpoints" if
            anything in the check or the download fails (the underlying error
            is printed first).
    """
    try:
        weights_dir = os.path.join(os.path.dirname(__file__), "weights")
        if os.path.exists(weights_dir):
            print("Checkpoints already downloaded, skipping...")
        else:
            print("Downloading checkpoints...")
            # Imported lazily so gdown is only required when a download happens.
            from gdown import download_folder

            download_folder(
                id="1epln5c8HW1QXfVz6444Fe0hG-vRNavi6",
                output=weights_dir,
                quiet=False,
                use_cookies=False,
            )
    except Exception as err:
        print(err)
        raise Exception("Error while downloading checkpoints")
if __name__ == "__main__":
    # Example settings noted by the original author:
    #   GPEN-BFR-512 -> size 512, channel_multiplier 2, narrow 1
    #   GPEN-BFR-256 -> size 256, channel_multiplier 1, narrow 0.5
    cli_options = (
        ("--model", dict(type=str, default="GPEN-BFR-512", help="GPEN model")),
        ("--size", dict(type=int, default=512, help="resolution of GPEN")),
        ("--channel_multiplier", dict(type=int, default=2, help="channel multiplier of GPEN")),
        ("--narrow", dict(type=float, default=1, help="channel narrow scale")),
        ("--use_sr", dict(action="store_true", help="use sr or not")),
        ("--sr_model", dict(type=str, default="realesrnet_x2", help="SR model")),
        ("--sr_scale", dict(type=int, default=2, help="SR scale")),
        ("--indir", dict(type=str, default="examples/imgs", help="input folder")),
        ("--outdir", dict(type=str, default="results/outs-BFR", help="output folder")),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    os.makedirs(args.outdir, exist_ok=True)

    faceenhancer = FaceEnhancement(
        size=args.size,
        model=args.model,
        use_sr=args.use_sr,
        sr_model=args.sr_model,
        channel_multiplier=args.channel_multiplier,
        narrow=args.narrow,
    )

    # "*.*g" picks up extensions ending in "g" (jpg, jpeg, png, ...).
    image_paths = sorted(glob.glob(os.path.join(args.indir, "*.*g")))
    for n, path in enumerate(image_paths):
        filename = os.path.basename(path)

        im = cv2.imread(path, cv2.IMREAD_COLOR)  # BGR
        if not isinstance(im, np.ndarray):
            # imread did not return an image array: report and move on.
            print(filename, "error")
            continue

        img, orig_faces, enhanced_faces = faceenhancer.process(im)

        # File name without its last extension, shared by every output.
        stem = ".".join(filename.split(".")[:-1])

        # Side-by-side comparison needs the input at the output resolution.
        im = cv2.resize(im, img.shape[:2][::-1])
        cv2.imwrite(os.path.join(args.outdir, stem + "_COMP.jpg"), np.hstack((im, img)))
        cv2.imwrite(os.path.join(args.outdir, stem + "_GPEN.jpg"), img)

        for m, (ef, of) in enumerate(zip(enhanced_faces, orig_faces)):
            of = cv2.resize(of, ef.shape[:2])
            cv2.imwrite(os.path.join(args.outdir, stem + "_face%02d" % m + ".jpg"), np.hstack((of, ef)))

        if n % 10 == 0:
            print(n, filename)
def make_kernel(k):
    """Return ``k`` as a float32 tensor normalised to sum to one.

    A 1-D input is first expanded to a 2-D kernel by multiplying it with its
    own transpose (outer product); inputs of any other rank are only
    normalised.
    """
    kernel = torch.tensor(k, dtype=torch.float32)

    if kernel.ndim == 1:
        # (1, n) * (n, 1) broadcasts to the (n, n) outer product.
        kernel = kernel.unsqueeze(0) * kernel.unsqueeze(1)

    return kernel / kernel.sum()
class EqualLinear(nn.Module):
    """Linear layer whose weight and bias are rescaled on every forward pass.

    The weight is initialised from a standard normal divided by ``lr_mul`` and
    multiplied by ``lr_mul / sqrt(in_dim)`` at call time; the bias, when
    present, is multiplied by ``lr_mul``.

    Args:
        in_dim: number of input features.
        out_dim: number of output features.
        bias: whether to learn a bias vector.
        bias_init: initial value of every bias entry.
        lr_mul: multiplier applied as described above.
        activation: any truthy value routes the output through
            ``fused_leaky_relu``; a falsy value gives a plain linear layer.
    """

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        # ``self.bias`` is None when the layer was built with ``bias=False``;
        # multiplying it unconditionally raised TypeError in that case.
        bias = None if self.bias is None else self.bias * self.lr_mul

        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            if bias is None:
                # The fused op takes a bias tensor; zeros are the no-bias case.
                bias = out.new_zeros(out.shape[-1])
            return fused_leaky_relu(out, bias)

        return F.linear(input, self.weight * self.scale, bias=bias)

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )
class NoiseInjection(nn.Module):
    """Inject noise, scaled by a single learned weight, into a feature map.

    With ``isconcat=True`` the scaled noise is appended along the channel
    axis; otherwise it is added to the input.  The weight starts at zero.
    """

    def __init__(self, isconcat=True):
        super().__init__()

        self.isconcat = isconcat
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        """Return ``image`` combined with ``weight * noise``.

        When ``noise`` is omitted, a single-channel standard-normal map with
        the batch and spatial size of ``image`` is drawn.
        """
        if noise is None:
            n, _, h, w = image.shape
            noise = image.new_empty(n, 1, h, w).normal_()

        scaled = self.weight * noise

        if not self.isconcat:
            return image + scaled
        return torch.cat((image, scaled), dim=1)
class Generator(nn.Module):
    """Style-modulated synthesis network with per-layer noise inputs.

    A mapping network (``PixelNorm`` followed by ``n_mlp`` ``EqualLinear``
    layers) turns input codes into styles.  Synthesis starts from a learned
    4x4 constant and, for every resolution from 8 up to ``size``, applies an
    upsampling ``StyledConv``, a second ``StyledConv`` and a ``ToRGB`` whose
    output is accumulated into the image.

    Args:
        size: output resolution; ``int(log2(size))`` stages are built.
        style_dim: width of the style vectors.
        n_mlp: number of ``EqualLinear`` layers in the mapping network.
        channel_multiplier: width multiplier for resolutions of 64 and above.
        blur_kernel: kernel forwarded to the styled convolutions.
        lr_mlp: ``lr_mul`` of the mapping layers.
        isconcat: forwarded to every ``StyledConv``; when true the feature
            width seen by the following layer is doubled.
        narrow: global width multiplier.
    """

    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
        isconcat=True,
        narrow=1
    ):
        super().__init__()

        self.size = size
        self.n_mlp = n_mlp
        self.style_dim = style_dim
        self.feat_multiplier = 2 if isconcat else 1

        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
                )
            )

        self.style = nn.Sequential(*layers)

        # Nominal feature width per resolution.
        self.channels = {
            4: int(512 * narrow),
            8: int(512 * narrow),
            16: int(512 * narrow),
            32: int(512 * narrow),
            64: int(256 * channel_multiplier * narrow),
            128: int(128 * channel_multiplier * narrow),
            256: int(64 * channel_multiplier * narrow),
            512: int(32 * channel_multiplier * narrow),
            1024: int(16 * channel_multiplier * narrow)
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel, isconcat=isconcat
        )
        self.to_rgb1 = ToRGB(self.channels[4]*self.feat_multiplier, style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))

        self.convs = nn.ModuleList()
        # Kept for attribute compatibility; nothing is appended to it here.
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()

        in_channel = self.channels[4]

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel*self.feat_multiplier,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                    isconcat=isconcat
                )
            )

            self.convs.append(
                StyledConv(
                    out_channel*self.feat_multiplier, out_channel, 3, style_dim, blur_kernel=blur_kernel, isconcat=isconcat
                )
            )

            self.to_rgbs.append(ToRGB(out_channel*self.feat_multiplier, style_dim))

            in_channel = out_channel

        # One latent per styled conv / to-RGB slot consumed in ``forward``.
        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        """Return single-sample, single-channel noise maps: one 4x4, then two per resolution."""
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

        return noises

    def mean_latent(self, n_latent):
        """Return the mean style of ``n_latent`` random codes, shape (1, style_dim)."""
        latent_in = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        """Run ``input`` through the mapping network."""
        return self.style(input)

    def _sample_layer_noise(self, batch, device, dtype=None):
        """Draw one Gaussian noise map per styled convolution.

        The list is laid out the way ``forward`` indexes it: entry 0 feeds
        ``conv1`` at 4x4, then two entries per resolution from 8 up to the
        output size.  Each map has the nominal channel count of its
        resolution so it matches the convolution output it is combined with.
        """
        def draw(res):
            return torch.randn(
                batch, self.channels[res], res, res, device=device, dtype=dtype
            )

        noise = [draw(4)]
        for i in range(3, self.log_size + 1):
            noise.extend(draw(2 ** i) for _ in range(2))

        return noise

    def forward(
        self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
    ):
        """Synthesise an image from one or two style codes.

        Args:
            styles: list with one or two codes (or styles, see
                ``input_is_latent``).
            return_latents: also return the per-layer latent tensor.
            inject_index: layer at which the second style takes over; drawn
                at random when two styles are given and this is ``None``.
            truncation: values below 1 pull the styles towards
                ``truncation_latent``.
            truncation_latent: anchor for truncation; required when
                ``truncation < 1``.
            input_is_latent: skip the mapping network.
            noise: list of per-layer noise maps indexed as described in
                ``_sample_layer_noise``; sampled when ``None``.

        Returns:
            ``(image, latent)`` when ``return_latents`` is true, otherwise
            ``(image, None)``.
        """
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            # The previous fallback built a single map per resolution over
            # ``range(n_mlp + 1)``, while the loop below consumes two maps per
            # resolution; the maps therefore stopped matching the feature
            # sizes from the second 8x8 convolution onwards.
            noise = self._sample_layer_noise(
                styles[0].shape[0], styles[0].device, styles[0].dtype
            )

        if truncation < 1:
            styles = [
                truncation_latent + truncation * (style - truncation_latent)
                for style in styles
            ]

        if len(styles) < 2:
            inject_index = self.n_latent

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])

        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)

            i += 2

        image = skip

        if return_latents:
            return image, latent

        return image, None
class ResBlock(nn.Module):
    """Downsampling residual block built from three ``ConvLayer`` modules.

    The main path is a 3x3 layer followed by a downsampling 3x3 layer; the
    shortcut is a downsampling 1x1 layer without bias or activation.  The two
    are summed and divided by ``sqrt(2)``.

    ``blur_kernel`` is accepted for signature compatibility but is not
    forwarded to the layers, which therefore use their own default.
    """

    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        main = self.conv2(self.conv1(input))
        shortcut = self.skip(input)

        return (main + shortcut) / math.sqrt(2)
class Discriminator(nn.Module):
    """Convolutional critic that reduces an image to one score per sample.

    A 1x1 ``ConvLayer`` stem is followed by one ``ResBlock`` per halving of
    resolution.  Before the final convolution a group-wise standard-deviation
    statistic is appended as an extra channel; the result is flattened and
    passed through two ``EqualLinear`` layers.

    Args:
        size: input resolution; must be a key of the channel table.
        channel_multiplier: width multiplier for resolutions of 64 and above.
        blur_kernel: passed to every ``ResBlock``.
        narrow: global width multiplier.
    """

    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], narrow=1):
        super().__init__()

        # Base widths; the second table is additionally scaled by channel_multiplier.
        fixed_width = {4: 512, 8: 512, 16: 512, 32: 512}
        scaled_width = {64: 256, 128: 128, 256: 64, 512: 32, 1024: 16}
        channels = {res: int(w * narrow) for res, w in fixed_width.items()}
        channels.update(
            (res, int(w * channel_multiplier * narrow)) for res, w in scaled_width.items()
        )

        in_channel = channels[size]
        blocks = [ConvLayer(3, in_channel, 1)]

        for exponent in range(int(math.log(size, 2)), 2, -1):
            out_channel = channels[2 ** (exponent - 1)]
            blocks.append(ResBlock(in_channel, out_channel, blur_kernel))
            in_channel = out_channel

        self.convs = nn.Sequential(*blocks)

        self.stddev_group = 4
        self.stddev_feat = 1

        # "+ 1" accounts for the appended standard-deviation channel.
        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        feat = self.convs(input)

        batch, channel, height, width = feat.shape
        group = min(batch, self.stddev_group)

        # Standard deviation across each group of samples, averaged over
        # channels and positions, then tiled back to one map per sample.
        grouped = feat.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(grouped.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev_map = stddev.repeat(group, 1, height, width)

        feat = self.final_conv(torch.cat([feat, stddev_map], 1))

        return self.final_linear(feat.view(batch, -1))
class FusedLeakyReLUFunction(Function):
    """Autograd Function that runs the compiled ``fused.fused_bias_act`` op with act=3."""

    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        # An empty tensor fills the op's reference-tensor slot on the forward pass.
        no_ref = input.new_empty(0)
        result = fused.fused_bias_act(input, bias, no_ref, 3, 0, negative_slope, scale)

        # Backward needs the activated output plus both scalars.
        ctx.negative_slope, ctx.scale = negative_slope, scale
        ctx.save_for_backward(result)

        return result

    @staticmethod
    def backward(ctx, grad_output):
        (activated,) = ctx.saved_tensors

        grads = FusedLeakyReLUFunctionBackward.apply(
            grad_output, activated, ctx.negative_slope, ctx.scale
        )
        grad_input, grad_bias = grads

        # negative_slope and scale are plain scalars: no gradient for them.
        return grad_input, grad_bias, None, None


class FusedLeakyReLU(nn.Module):
    """Module form of :func:`fused_leaky_relu` that owns a zero-initialised bias.

    Args:
        channel: Length of the learnable bias vector.
        negative_slope: Slope forwarded to the fused op.
        scale: Output gain forwarded to the fused op.
    """

    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        self.negative_slope = negative_slope
        self.scale = scale
        self.bias = nn.Parameter(torch.zeros(channel))

    def forward(self, input):
        """Apply the fused bias + leaky ReLU using this module's parameters."""
        activated = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
        return activated


def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    """Functional entry point: dispatch to ``FusedLeakyReLUFunction.apply``."""
    args = (input, bias, negative_slope, scale)
    return FusedLeakyReLUFunction.apply(*args)
// Python-facing entry point for the fused bias + activation op (bound below via pybind11).
//
// Rejects `input` and `bias` that are not CUDA tensors (CHECK_CUDA expands to a TORCH_CHECK),
// then forwards every argument unchanged to fused_bias_act_op and returns its result.
// `refer` is passed through without a device check.
torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    CHECK_CUDA(input);
    CHECK_CUDA(bias);

    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}
// Host-side launcher for fused_bias_act_kernel.
//
// Makes contiguous copies of the three tensors, derives the indexing constants the kernel
// needs, and launches the kernel on the current CUDA stream.
//
//   input  - tensor to transform; the result has the same shape and dtype (empty_like).
//   bias   - bias values; an empty tensor disables the bias add (use_bias = 0).
//   refer  - reference tensor read by the kernel when grad != 0; empty disables it.
//   act, grad, alpha, scale - forwarded unchanged to the kernel.
//
// The launch previously read `fused_bias_act_kernel<<>>(` with untyped `data_ptr()` calls,
// which is not valid CUDA C++. The template argument, the <<<grid, block, smem, stream>>>
// execution configuration and the typed data_ptr<scalar_t>() accessors are restored here.
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    int curDevice = -1;
    cudaGetDevice(&curDevice);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

    auto x = input.contiguous();
    auto b = bias.contiguous();
    auto ref = refer.contiguous();

    // Empty tensors act as "not provided" flags for the kernel.
    int use_bias = b.numel() ? 1 : 0;
    int use_ref = ref.numel() ? 1 : 0;

    int size_x = x.numel();
    int size_b = b.numel();
    int step_b = 1;

    // step_b = product of all dims after dim 1; the kernel selects the bias element
    // with (xi / step_b) % size_b.
    for (int i = 1 + 1; i < x.dim(); i++) {
        step_b *= x.size(i);
    }

    // Each thread handles up to loop_x elements; each block covers loop_x * block_size.
    int loop_x = 4;
    int block_size = 4 * 32;
    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

    auto y = torch::empty_like(x);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(),
            x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(),
            ref.data_ptr<scalar_t>(),
            act,
            grad,
            alpha,
            scale,
            loop_x,
            size_x,
            step_b,
            size_b,
            use_bias,
            use_ref
        );
    });

    return y;
}
class UpFirDn2dBackward(Function):
    """Gradient of ``UpFirDn2d``, itself differentiable for double backward.

    ``forward`` runs the compiled op on the incoming gradient with the up and
    down factors swapped and the flipped kernel; ``backward`` re-applies the
    original op configuration to the second-order gradient.
    """

    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):
        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        # The op is fed a 4-D tensor with a trailing singleton dim.
        flat_grad = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        # Note the swapped roles: down factors go in the "up" slots and vice versa.
        grad_input = upfirdn2d_op.upfirdn2d(
            flat_grad,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        ).view(in_size[0], in_size[1], in_size[2], in_size[3])

        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        # Stash the original (un-swapped) configuration for the double-backward pass.
        ctx.up_x, ctx.up_y = up_x, up_y
        ctx.down_x, ctx.down_y = down_x, down_y
        ctx.pad_x0, ctx.pad_x1 = pad_x0, pad_x1
        ctx.pad_y0, ctx.pad_y1 = pad_y0, pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        (kernel,) = ctx.saved_tensors

        flat = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        gradgrad_out = upfirdn2d_op.upfirdn2d(
            flat,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
        )

        # Only grad_output (the first forward argument) receives a gradient;
        # the remaining eight forward arguments get None.
        return (gradgrad_out,) + (None,) * 8
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample, FIR-filter and downsample ``input`` through ``UpFirDn2d``.

    Args:
        input: Tensor handed to ``UpFirDn2d.apply``.
        kernel: 2-D filter taps.
        up: Upsampling factor, used for both x and y.
        down: Downsampling factor, used for both x and y.
        pad: ``(pad0, pad1)`` pair, reused for both the x and the y axis.

    Returns:
        The tensor produced by ``UpFirDn2d.apply``.
    """
    out = UpFirDn2d.apply(
        input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
    )

    return out


def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    """Pure-PyTorch upsample / FIR filter / downsample.

    Args:
        input: 4-D tensor unpacked as ``(_, in_h, in_w, minor)``.
        kernel: 2-D filter of shape ``(kernel_h, kernel_w)``.
        up_x, up_y: Zero-insertion upsampling factors.
        down_x, down_y: Subsampling strides applied to the filtered result.
        pad_x0, pad_x1, pad_y0, pad_y1: Padding per side; negative values crop.

    Returns:
        Tensor laid out like ``input`` (height and width at dims 1 and 2).
    """
    # This module does not import torch.nn.functional at file level, so the
    # name F was previously undefined here (NameError on first call).
    import torch.nn.functional as F

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Zero-insertion upsampling: give every pixel its own singleton axes,
    # pad those axes with zeros, then fold them back into H and W.
    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive padding is added here; negative padding is applied as a crop below.
    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    # Move `minor` next to the batch dim and merge them so that each plane is
    # filtered independently by a single-channel convolution.
    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    # conv2d computes a correlation; flipping the taps turns it into a convolution.
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)

    return out[:, ::down_y, ::down_x, :]
// Integer division that rounds toward negative infinity.
// C++ `/` truncates toward zero, so the quotient is stepped down by one whenever
// multiplying it back overshoots the dividend.
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
    const int quotient = a / b;

    return (quotient * b > a) ? quotient - 1 : quotient;
}
// Host-side launcher for upfirdn2d_kernel.
//
// Fills UpFirDn2DKernelParams from the tensors and scalar arguments, allocates the output
// as {major_dim, out_h, out_w, minor_dim}, selects a kernel specialisation ("mode") from
// the up/down factors and kernel extent, and launches it on the current CUDA stream.
//
// Changes relative to the previous text:
//  * tile_out_h / tile_out_w are initialised; they used to be read uninitialised when no
//    mode matched.
//  * An unsupported configuration now fails with TORCH_CHECK instead of silently returning
//    the uninitialised `out` buffer.
//  * The kernel launches had lost their template arguments, launch configuration and typed
//    data_ptr<scalar_t>() accessors (`upfirdn2d_kernel<<>>(`). They are reconstructed from
//    each mode's guard (up, down, maximum kernel extent) and tile size.
//    NOTE(review): the template parameter order (scalar_t, up_x, up_y, down_x, down_y,
//    kernel_h, kernel_w, tile_out_h, tile_out_w) is assumed; confirm it against the kernel
//    declaration.
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
    int up_x, int up_y, int down_x, int down_y,
    int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
    int curDevice = -1;
    cudaGetDevice(&curDevice);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

    UpFirDn2DKernelParams p;

    auto x = input.contiguous();
    auto k = kernel.contiguous();

    p.major_dim = x.size(0);
    p.in_h = x.size(1);
    p.in_w = x.size(2);
    p.minor_dim = x.size(3);
    p.kernel_h = k.size(0);
    p.kernel_w = k.size(1);
    p.up_x = up_x;
    p.up_y = up_y;
    p.down_x = down_x;
    p.down_y = down_y;
    p.pad_x0 = pad_x0;
    p.pad_x1 = pad_x1;
    p.pad_y0 = pad_y0;
    p.pad_y1 = pad_y1;

    p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
    p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;

    auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

    int mode = -1;

    int tile_out_h = -1;
    int tile_out_w = -1;

    // Later matches override earlier ones, so the tightest kernel extent wins.
    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
        mode = 1;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
        mode = 2;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
        mode = 3;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
        mode = 4;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
        mode = 5;
        tile_out_h = 8;
        tile_out_w = 32;
    }

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
        mode = 6;
        tile_out_h = 8;
        tile_out_w = 32;
    }

    TORCH_CHECK(mode > 0,
        "upfirdn2d: unsupported configuration (up=", p.up_x, "x", p.up_y,
        ", down=", p.down_x, "x", p.down_y,
        ", kernel=", p.kernel_h, "x", p.kernel_w, ")");

    dim3 block_size;
    dim3 grid_size;

    if (tile_out_h > 0 && tile_out_w > 0) {
        p.loop_major = (p.major_dim - 1) / 16384 + 1;
        p.loop_x = 1;
        block_size = dim3(32 * 8, 1, 1);
        grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                         (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                         (p.major_dim - 1) / p.loop_major + 1);
    }

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
        switch (mode) {
        case 1:
            upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;

        case 2:
            upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;

        case 3:
            upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;

        case 4:
            upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;

        case 5:
            upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;

        case 6:
            upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32><<<grid_size, block_size, 0, stream>>>(
                out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
            );

            break;
        }
    });

    return out;
}
+2002/07/23/big/img_474 +2002/07/27/big/img_970 +2002/09/02/big/img_15752 +2002/09/01/big/img_16378 +2002/09/01/big/img_16189 +2002/08/26/big/img_276 +2002/07/24/big/img_518 +2002/08/14/big/img_1027 +2002/08/24/big/img_733 +2002/08/15/big/img_249 +2003/01/15/big/img_1371 +2002/08/07/big/img_1348 +2003/01/01/big/img_331 +2002/08/23/big/img_536 +2002/07/30/big/img_224 +2002/08/10/big/img_763 +2002/08/21/big/img_293 +2002/08/15/big/img_1211 +2002/08/15/big/img_1194 +2003/01/15/big/img_390 +2002/08/06/big/img_2893 +2002/08/17/big/img_691 +2002/08/07/big/img_1695 +2002/08/16/big/img_829 +2002/07/25/big/img_201 +2002/08/23/big/img_36 +2003/01/15/big/img_763 +2003/01/15/big/img_637 +2002/08/22/big/img_592 +2002/07/25/big/img_817 +2003/01/15/big/img_1219 +2002/08/05/big/img_3508 +2002/08/15/big/img_1108 +2002/07/19/big/img_488 +2003/01/16/big/img_704 +2003/01/13/big/img_1087 +2002/08/10/big/img_670 +2002/07/24/big/img_104 +2002/08/27/big/img_19823 +2002/09/01/big/img_16229 +2003/01/13/big/img_846 +2002/08/04/big/img_412 +2002/07/22/big/img_554 +2002/08/12/big/img_331 +2002/08/02/big/img_533 +2002/08/12/big/img_259 +2002/08/18/big/img_328 +2003/01/14/big/img_630 +2002/08/05/big/img_3541 +2002/08/06/big/img_2390 +2002/08/20/big/img_150 +2002/08/02/big/img_1231 +2002/08/16/big/img_710 +2002/08/19/big/img_591 +2002/07/22/big/img_725 +2002/07/24/big/img_820 +2003/01/13/big/img_568 +2002/08/22/big/img_853 +2002/08/09/big/img_648 +2002/08/23/big/img_528 +2003/01/14/big/img_888 +2002/08/30/big/img_18201 +2002/08/13/big/img_965 +2003/01/14/big/img_660 +2002/07/19/big/img_517 +2003/01/14/big/img_406 +2002/08/30/big/img_18433 +2002/08/07/big/img_1630 +2002/08/06/big/img_2717 +2002/08/21/big/img_470 +2002/07/23/big/img_633 +2002/08/20/big/img_915 +2002/08/16/big/img_893 +2002/07/29/big/img_644 +2002/08/15/big/img_529 +2002/08/16/big/img_668 +2002/08/07/big/img_1871 +2002/07/25/big/img_192 +2002/07/31/big/img_961 +2002/08/19/big/img_738 +2002/07/31/big/img_382 +2002/08/19/big/img_298 
+2003/01/17/big/img_608 +2002/08/21/big/img_514 +2002/07/23/big/img_183 +2003/01/17/big/img_536 +2002/07/24/big/img_478 +2002/08/06/big/img_2997 +2002/09/02/big/img_15380 +2002/08/07/big/img_1153 +2002/07/31/big/img_967 +2002/07/31/big/img_711 +2002/08/26/big/img_664 +2003/01/01/big/img_326 +2002/08/24/big/img_775 +2002/08/08/big/img_961 +2002/08/16/big/img_77 +2002/08/12/big/img_296 +2002/07/22/big/img_905 +2003/01/13/big/img_284 +2002/08/13/big/img_887 +2002/08/24/big/img_849 +2002/07/30/big/img_345 +2002/08/18/big/img_419 +2002/08/01/big/img_1347 +2002/08/05/big/img_3670 +2002/07/21/big/img_479 +2002/08/08/big/img_913 +2002/09/02/big/img_15828 +2002/08/30/big/img_18194 +2002/08/08/big/img_471 +2002/08/22/big/img_734 +2002/08/09/big/img_586 +2002/08/09/big/img_454 +2002/07/29/big/img_47 +2002/07/19/big/img_381 +2002/07/29/big/img_733 +2002/08/20/big/img_327 +2002/07/21/big/img_96 +2002/08/06/big/img_2680 +2002/07/25/big/img_919 +2002/07/21/big/img_158 +2002/07/22/big/img_801 +2002/07/22/big/img_567 +2002/07/24/big/img_804 +2002/07/24/big/img_690 +2003/01/15/big/img_576 +2002/08/14/big/img_335 +2003/01/13/big/img_390 +2002/08/11/big/img_258 +2002/07/23/big/img_917 +2002/08/15/big/img_525 +2003/01/15/big/img_505 +2002/07/30/big/img_886 +2003/01/16/big/img_640 +2003/01/14/big/img_642 +2003/01/17/big/img_844 +2002/08/04/big/img_571 +2002/08/29/big/img_18702 +2003/01/15/big/img_240 +2002/07/29/big/img_553 +2002/08/10/big/img_354 +2002/08/18/big/img_17 +2003/01/15/big/img_782 +2002/07/27/big/img_382 +2002/08/14/big/img_970 +2003/01/16/big/img_70 +2003/01/16/big/img_625 +2002/08/18/big/img_341 +2002/08/26/big/img_188 +2002/08/09/big/img_405 +2002/08/02/big/img_37 +2002/08/13/big/img_748 +2002/07/22/big/img_399 +2002/07/25/big/img_844 +2002/08/12/big/img_340 +2003/01/13/big/img_815 +2002/08/26/big/img_5 +2002/08/10/big/img_158 +2002/08/18/big/img_95 +2002/07/29/big/img_1297 +2003/01/13/big/img_508 +2002/09/01/big/img_16680 +2003/01/16/big/img_338 +2002/08/13/big/img_517 
+2002/07/22/big/img_626 +2002/08/06/big/img_3024 +2002/07/26/big/img_499 +2003/01/13/big/img_387 +2002/08/31/big/img_18025 +2002/08/13/big/img_520 +2003/01/16/big/img_576 +2002/07/26/big/img_121 +2002/08/25/big/img_703 +2002/08/26/big/img_615 +2002/08/17/big/img_434 +2002/08/02/big/img_677 +2002/08/18/big/img_276 +2002/08/05/big/img_3672 +2002/07/26/big/img_700 +2002/07/31/big/img_277 +2003/01/14/big/img_220 +2002/08/23/big/img_232 +2002/08/31/big/img_17422 +2002/07/22/big/img_508 +2002/08/13/big/img_681 +2003/01/15/big/img_638 +2002/08/30/big/img_18408 +2003/01/14/big/img_533 +2003/01/17/big/img_12 +2002/08/28/big/img_19388 +2002/08/08/big/img_133 +2002/07/26/big/img_885 +2002/08/19/big/img_387 +2002/08/27/big/img_19976 +2002/08/26/big/img_118 +2002/08/28/big/img_19146 +2002/08/05/big/img_3259 +2002/08/15/big/img_536 +2002/07/22/big/img_279 +2002/07/22/big/img_9 +2002/08/13/big/img_301 +2002/08/15/big/img_974 +2002/08/06/big/img_2355 +2002/08/01/big/img_1526 +2002/08/03/big/img_417 +2002/08/04/big/img_407 +2002/08/15/big/img_1029 +2002/07/29/big/img_700 +2002/08/01/big/img_1463 +2002/08/31/big/img_17365 +2002/07/28/big/img_223 +2002/07/19/big/img_827 +2002/07/27/big/img_531 +2002/07/19/big/img_845 +2002/08/20/big/img_382 +2002/07/31/big/img_268 +2002/08/27/big/img_19705 +2002/08/02/big/img_830 +2002/08/23/big/img_250 +2002/07/20/big/img_777 +2002/08/21/big/img_879 +2002/08/26/big/img_20146 +2002/08/23/big/img_789 +2002/08/06/big/img_2683 +2002/08/25/big/img_576 +2002/08/09/big/img_498 +2002/08/08/big/img_384 +2002/08/26/big/img_592 +2002/07/29/big/img_1470 +2002/08/21/big/img_452 +2002/08/30/big/img_18395 +2002/08/15/big/img_215 +2002/07/21/big/img_643 +2002/07/22/big/img_209 +2003/01/17/big/img_346 +2002/08/25/big/img_658 +2002/08/21/big/img_221 +2002/08/14/big/img_60 +2003/01/17/big/img_885 +2003/01/16/big/img_482 +2002/08/19/big/img_593 +2002/08/08/big/img_233 +2002/07/30/big/img_458 +2002/07/23/big/img_384 +2003/01/15/big/img_670 +2003/01/15/big/img_267 
+2002/08/26/big/img_540 +2002/07/29/big/img_552 +2002/07/30/big/img_997 +2003/01/17/big/img_377 +2002/08/21/big/img_265 +2002/08/09/big/img_561 +2002/07/31/big/img_945 +2002/09/02/big/img_15252 +2002/08/11/big/img_276 +2002/07/22/big/img_491 +2002/07/26/big/img_517 +2002/08/14/big/img_726 +2002/08/08/big/img_46 +2002/08/28/big/img_19458 +2002/08/06/big/img_2935 +2002/07/29/big/img_1392 +2002/08/13/big/img_776 +2002/08/24/big/img_616 +2002/08/14/big/img_1065 +2002/07/29/big/img_889 +2002/08/18/big/img_188 +2002/08/07/big/img_1453 +2002/08/02/big/img_760 +2002/07/28/big/img_416 +2002/08/07/big/img_1393 +2002/08/26/big/img_292 +2002/08/26/big/img_301 +2003/01/13/big/img_195 +2002/07/26/big/img_532 +2002/08/20/big/img_550 +2002/08/05/big/img_3658 +2002/08/26/big/img_738 +2002/09/02/big/img_15750 +2003/01/17/big/img_451 +2002/07/23/big/img_339 +2002/08/16/big/img_637 +2002/08/14/big/img_748 +2002/08/06/big/img_2739 +2002/07/25/big/img_482 +2002/08/19/big/img_191 +2002/08/26/big/img_537 +2003/01/15/big/img_716 +2003/01/15/big/img_767 +2002/08/02/big/img_452 +2002/08/08/big/img_1011 +2002/08/10/big/img_144 +2003/01/14/big/img_122 +2002/07/24/big/img_586 +2002/07/24/big/img_762 +2002/08/20/big/img_369 +2002/07/30/big/img_146 +2002/08/23/big/img_396 +2003/01/15/big/img_200 +2002/08/15/big/img_1183 +2003/01/14/big/img_698 +2002/08/09/big/img_792 +2002/08/06/big/img_2347 +2002/07/31/big/img_911 +2002/08/26/big/img_722 +2002/08/23/big/img_621 +2002/08/05/big/img_3790 +2003/01/13/big/img_633 +2002/08/09/big/img_224 +2002/07/24/big/img_454 +2002/07/21/big/img_202 +2002/08/02/big/img_630 +2002/08/30/big/img_18315 +2002/07/19/big/img_491 +2002/09/01/big/img_16456 +2002/08/09/big/img_242 +2002/07/25/big/img_595 +2002/07/22/big/img_522 +2002/08/01/big/img_1593 +2002/07/29/big/img_336 +2002/08/15/big/img_448 +2002/08/28/big/img_19281 +2002/07/29/big/img_342 +2002/08/12/big/img_78 +2003/01/14/big/img_525 +2002/07/28/big/img_147 +2002/08/11/big/img_353 +2002/08/22/big/img_513 
+2002/08/04/big/img_721 +2002/08/17/big/img_247 +2003/01/14/big/img_891 +2002/08/20/big/img_853 +2002/07/19/big/img_414 +2002/08/01/big/img_1530 +2003/01/14/big/img_924 +2002/08/22/big/img_468 +2002/08/18/big/img_354 +2002/08/30/big/img_18193 +2002/08/23/big/img_492 +2002/08/15/big/img_871 +2002/08/12/big/img_494 +2002/08/06/big/img_2470 +2002/07/23/big/img_923 +2002/08/26/big/img_155 +2002/08/08/big/img_669 +2002/07/23/big/img_404 +2002/08/28/big/img_19421 +2002/08/29/big/img_18993 +2002/08/25/big/img_416 +2003/01/17/big/img_434 +2002/07/29/big/img_1370 +2002/07/28/big/img_483 +2002/08/11/big/img_50 +2002/08/10/big/img_404 +2002/09/02/big/img_15057 +2003/01/14/big/img_911 +2002/09/01/big/img_16697 +2003/01/16/big/img_665 +2002/09/01/big/img_16708 +2002/08/22/big/img_612 +2002/08/28/big/img_19471 +2002/08/02/big/img_198 +2003/01/16/big/img_527 +2002/08/22/big/img_209 +2002/08/30/big/img_18205 +2003/01/14/big/img_114 +2003/01/14/big/img_1028 +2003/01/16/big/img_894 +2003/01/14/big/img_837 +2002/07/30/big/img_9 +2002/08/06/big/img_2821 +2002/08/04/big/img_85 +2003/01/13/big/img_884 +2002/07/22/big/img_570 +2002/08/07/big/img_1773 +2002/07/26/big/img_208 +2003/01/17/big/img_946 +2002/07/19/big/img_930 +2003/01/01/big/img_698 +2003/01/17/big/img_612 +2002/07/19/big/img_372 +2002/07/30/big/img_721 +2003/01/14/big/img_649 +2002/08/19/big/img_4 +2002/07/25/big/img_1024 +2003/01/15/big/img_601 +2002/08/30/big/img_18470 +2002/07/22/big/img_29 +2002/08/07/big/img_1686 +2002/07/20/big/img_294 +2002/08/14/big/img_800 +2002/08/19/big/img_353 +2002/08/19/big/img_350 +2002/08/05/big/img_3392 +2002/08/09/big/img_622 +2003/01/15/big/img_236 +2002/08/11/big/img_643 +2002/08/05/big/img_3458 +2002/08/12/big/img_413 +2002/08/22/big/img_415 +2002/08/13/big/img_635 +2002/08/07/big/img_1198 +2002/08/04/big/img_873 +2002/08/12/big/img_407 +2003/01/15/big/img_346 +2002/08/02/big/img_275 +2002/08/17/big/img_997 +2002/08/21/big/img_958 +2002/08/20/big/img_579 +2002/07/29/big/img_142 
+2003/01/14/big/img_1115 +2002/08/16/big/img_365 +2002/07/29/big/img_1414 +2002/08/17/big/img_489 +2002/08/13/big/img_1010 +2002/07/31/big/img_276 +2002/07/25/big/img_1000 +2002/08/23/big/img_524 +2002/08/28/big/img_19147 +2003/01/13/big/img_433 +2002/08/20/big/img_205 +2003/01/01/big/img_458 +2002/07/29/big/img_1449 +2003/01/16/big/img_696 +2002/08/28/big/img_19296 +2002/08/29/big/img_18688 +2002/08/21/big/img_767 +2002/08/20/big/img_532 +2002/08/26/big/img_187 +2002/07/26/big/img_183 +2002/07/27/big/img_890 +2003/01/13/big/img_576 +2002/07/30/big/img_15 +2002/07/31/big/img_889 +2002/08/31/big/img_17759 +2003/01/14/big/img_1114 +2002/07/19/big/img_445 +2002/08/03/big/img_593 +2002/07/24/big/img_750 +2002/07/30/big/img_133 +2002/08/25/big/img_671 +2002/07/20/big/img_351 +2002/08/31/big/img_17276 +2002/08/05/big/img_3231 +2002/09/02/big/img_15882 +2002/08/14/big/img_115 +2002/08/02/big/img_1148 +2002/07/25/big/img_936 +2002/07/31/big/img_639 +2002/08/04/big/img_427 +2002/08/22/big/img_843 +2003/01/17/big/img_17 +2003/01/13/big/img_690 +2002/08/13/big/img_472 +2002/08/09/big/img_425 +2002/08/05/big/img_3450 +2003/01/17/big/img_439 +2002/08/13/big/img_539 +2002/07/28/big/img_35 +2002/08/16/big/img_241 +2002/08/06/big/img_2898 +2003/01/16/big/img_429 +2002/08/05/big/img_3817 +2002/08/27/big/img_19919 +2002/07/19/big/img_422 +2002/08/15/big/img_560 +2002/07/23/big/img_750 +2002/07/30/big/img_353 +2002/08/05/big/img_43 +2002/08/23/big/img_305 +2002/08/01/big/img_2137 +2002/08/30/big/img_18097 +2002/08/01/big/img_1389 +2002/08/02/big/img_308 +2003/01/14/big/img_652 +2002/08/01/big/img_1798 +2003/01/14/big/img_732 +2003/01/16/big/img_294 +2002/08/26/big/img_213 +2002/07/24/big/img_842 +2003/01/13/big/img_630 +2003/01/13/big/img_634 +2002/08/06/big/img_2285 +2002/08/01/big/img_2162 +2002/08/30/big/img_18134 +2002/08/02/big/img_1045 +2002/08/01/big/img_2143 +2002/07/25/big/img_135 +2002/07/20/big/img_645 +2002/08/05/big/img_3666 +2002/08/14/big/img_523 
+2002/08/04/big/img_425 +2003/01/14/big/img_137 +2003/01/01/big/img_176 +2002/08/15/big/img_505 +2002/08/24/big/img_386 +2002/08/05/big/img_3187 +2002/08/15/big/img_419 +2003/01/13/big/img_520 +2002/08/04/big/img_444 +2002/08/26/big/img_483 +2002/08/05/big/img_3449 +2002/08/30/big/img_18409 +2002/08/28/big/img_19455 +2002/08/27/big/img_20090 +2002/07/23/big/img_625 +2002/08/24/big/img_205 +2002/08/08/big/img_938 +2003/01/13/big/img_527 +2002/08/07/big/img_1712 +2002/07/24/big/img_801 +2002/08/09/big/img_579 +2003/01/14/big/img_41 +2003/01/15/big/img_1130 +2002/07/21/big/img_672 +2002/08/07/big/img_1590 +2003/01/01/big/img_532 +2002/08/02/big/img_529 +2002/08/05/big/img_3591 +2002/08/23/big/img_5 +2003/01/14/big/img_882 +2002/08/28/big/img_19234 +2002/07/24/big/img_398 +2003/01/14/big/img_592 +2002/08/22/big/img_548 +2002/08/12/big/img_761 +2003/01/16/big/img_497 +2002/08/18/big/img_133 +2002/08/08/big/img_874 +2002/07/19/big/img_247 +2002/08/15/big/img_170 +2002/08/27/big/img_19679 +2002/08/20/big/img_246 +2002/08/24/big/img_358 +2002/07/29/big/img_599 +2002/08/01/big/img_1555 +2002/07/30/big/img_491 +2002/07/30/big/img_371 +2003/01/16/big/img_682 +2002/07/25/big/img_619 +2003/01/15/big/img_587 +2002/08/02/big/img_1212 +2002/08/01/big/img_2152 +2002/07/25/big/img_668 +2003/01/16/big/img_574 +2002/08/28/big/img_19464 +2002/08/11/big/img_536 +2002/07/24/big/img_201 +2002/08/05/big/img_3488 +2002/07/25/big/img_887 +2002/07/22/big/img_789 +2002/07/30/big/img_432 +2002/08/16/big/img_166 +2002/09/01/big/img_16333 +2002/07/26/big/img_1010 +2002/07/21/big/img_793 +2002/07/22/big/img_720 +2002/07/31/big/img_337 +2002/07/27/big/img_185 +2002/08/23/big/img_440 +2002/07/31/big/img_801 +2002/07/25/big/img_478 +2003/01/14/big/img_171 +2002/08/07/big/img_1054 +2002/09/02/big/img_15659 +2002/07/29/big/img_1348 +2002/08/09/big/img_337 +2002/08/26/big/img_684 +2002/07/31/big/img_537 +2002/08/15/big/img_808 +2003/01/13/big/img_740 +2002/08/07/big/img_1667 +2002/08/03/big/img_404 
+2002/08/06/big/img_2520 +2002/07/19/big/img_230 +2002/07/19/big/img_356 +2003/01/16/big/img_627 +2002/08/04/big/img_474 +2002/07/29/big/img_833 +2002/07/25/big/img_176 +2002/08/01/big/img_1684 +2002/08/21/big/img_643 +2002/08/27/big/img_19673 +2002/08/02/big/img_838 +2002/08/06/big/img_2378 +2003/01/15/big/img_48 +2002/07/30/big/img_470 +2002/08/15/big/img_963 +2002/08/24/big/img_444 +2002/08/16/big/img_662 +2002/08/15/big/img_1209 +2002/07/24/big/img_25 +2002/08/06/big/img_2740 +2002/07/29/big/img_996 +2002/08/31/big/img_18074 +2002/08/04/big/img_343 +2003/01/17/big/img_509 +2003/01/13/big/img_726 +2002/08/07/big/img_1466 +2002/07/26/big/img_307 +2002/08/10/big/img_598 +2002/08/13/big/img_890 +2002/08/14/big/img_997 +2002/07/19/big/img_392 +2002/08/02/big/img_475 +2002/08/29/big/img_19038 +2002/07/29/big/img_538 +2002/07/29/big/img_502 +2002/08/02/big/img_364 +2002/08/31/big/img_17353 +2002/08/08/big/img_539 +2002/08/01/big/img_1449 +2002/07/22/big/img_363 +2002/08/02/big/img_90 +2002/09/01/big/img_16867 +2002/08/05/big/img_3371 +2002/07/30/big/img_342 +2002/08/07/big/img_1363 +2002/08/22/big/img_790 +2003/01/15/big/img_404 +2002/08/05/big/img_3447 +2002/09/01/big/img_16167 +2003/01/13/big/img_840 +2002/08/22/big/img_1001 +2002/08/09/big/img_431 +2002/07/27/big/img_618 +2002/07/31/big/img_741 +2002/07/30/big/img_964 +2002/07/25/big/img_86 +2002/07/29/big/img_275 +2002/08/21/big/img_921 +2002/07/26/big/img_892 +2002/08/21/big/img_663 +2003/01/13/big/img_567 +2003/01/14/big/img_719 +2002/07/28/big/img_251 +2003/01/15/big/img_1123 +2002/07/29/big/img_260 +2002/08/24/big/img_337 +2002/08/01/big/img_1914 +2002/08/13/big/img_373 +2003/01/15/big/img_589 +2002/08/13/big/img_906 +2002/07/26/big/img_270 +2002/08/26/big/img_313 +2002/08/25/big/img_694 +2003/01/01/big/img_327 +2002/07/23/big/img_261 +2002/08/26/big/img_642 +2002/07/29/big/img_918 +2002/07/23/big/img_455 +2002/07/24/big/img_612 +2002/07/23/big/img_534 +2002/07/19/big/img_534 +2002/07/19/big/img_726 
+2002/08/01/big/img_2146 +2002/08/02/big/img_543 +2003/01/16/big/img_777 +2002/07/30/big/img_484 +2002/08/13/big/img_1161 +2002/07/21/big/img_390 +2002/08/06/big/img_2288 +2002/08/21/big/img_677 +2002/08/13/big/img_747 +2002/08/15/big/img_1248 +2002/07/31/big/img_416 +2002/09/02/big/img_15259 +2002/08/16/big/img_781 +2002/08/24/big/img_754 +2002/07/24/big/img_803 +2002/08/20/big/img_609 +2002/08/28/big/img_19571 +2002/09/01/big/img_16140 +2002/08/26/big/img_769 +2002/07/20/big/img_588 +2002/08/02/big/img_898 +2002/07/21/big/img_466 +2002/08/14/big/img_1046 +2002/07/25/big/img_212 +2002/08/26/big/img_353 +2002/08/19/big/img_810 +2002/08/31/big/img_17824 +2002/08/12/big/img_631 +2002/07/19/big/img_828 +2002/07/24/big/img_130 +2002/08/25/big/img_580 +2002/07/31/big/img_699 +2002/07/23/big/img_808 +2002/07/31/big/img_377 +2003/01/16/big/img_570 +2002/09/01/big/img_16254 +2002/07/21/big/img_471 +2002/08/01/big/img_1548 +2002/08/18/big/img_252 +2002/08/19/big/img_576 +2002/08/20/big/img_464 +2002/07/27/big/img_735 +2002/08/21/big/img_589 +2003/01/15/big/img_1192 +2002/08/09/big/img_302 +2002/07/31/big/img_594 +2002/08/23/big/img_19 +2002/08/29/big/img_18819 +2002/08/19/big/img_293 +2002/07/30/big/img_331 +2002/08/23/big/img_607 +2002/07/30/big/img_363 +2002/08/16/big/img_766 +2003/01/13/big/img_481 +2002/08/06/big/img_2515 +2002/09/02/big/img_15913 +2002/09/02/big/img_15827 +2002/09/02/big/img_15053 +2002/08/07/big/img_1576 +2002/07/23/big/img_268 +2002/08/21/big/img_152 +2003/01/15/big/img_578 +2002/07/21/big/img_589 +2002/07/20/big/img_548 +2002/08/27/big/img_19693 +2002/08/31/big/img_17252 +2002/07/31/big/img_138 +2002/07/23/big/img_372 +2002/08/16/big/img_695 +2002/07/27/big/img_287 +2002/08/15/big/img_315 +2002/08/10/big/img_361 +2002/07/29/big/img_899 +2002/08/13/big/img_771 +2002/08/21/big/img_92 +2003/01/15/big/img_425 +2003/01/16/big/img_450 +2002/09/01/big/img_16942 +2002/08/02/big/img_51 +2002/09/02/big/img_15379 +2002/08/24/big/img_147 
+2002/08/30/big/img_18122 +2002/07/26/big/img_950 +2002/08/07/big/img_1400 +2002/08/17/big/img_468 +2002/08/15/big/img_470 +2002/07/30/big/img_318 +2002/07/22/big/img_644 +2002/08/27/big/img_19732 +2002/07/23/big/img_601 +2002/08/26/big/img_398 +2002/08/21/big/img_428 +2002/08/06/big/img_2119 +2002/08/29/big/img_19103 +2003/01/14/big/img_933 +2002/08/11/big/img_674 +2002/08/28/big/img_19420 +2002/08/03/big/img_418 +2002/08/17/big/img_312 +2002/07/25/big/img_1044 +2003/01/17/big/img_671 +2002/08/30/big/img_18297 +2002/07/25/big/img_755 +2002/07/23/big/img_471 +2002/08/21/big/img_39 +2002/07/26/big/img_699 +2003/01/14/big/img_33 +2002/07/31/big/img_411 +2002/08/16/big/img_645 +2003/01/17/big/img_116 +2002/09/02/big/img_15903 +2002/08/20/big/img_120 +2002/08/22/big/img_176 +2002/07/29/big/img_1316 +2002/08/27/big/img_19914 +2002/07/22/big/img_719 +2002/08/28/big/img_19239 +2003/01/13/big/img_385 +2002/08/08/big/img_525 +2002/07/19/big/img_782 +2002/08/13/big/img_843 +2002/07/30/big/img_107 +2002/08/11/big/img_752 +2002/07/29/big/img_383 +2002/08/26/big/img_249 +2002/08/29/big/img_18860 +2002/07/30/big/img_70 +2002/07/26/big/img_194 +2002/08/15/big/img_530 +2002/08/08/big/img_816 +2002/07/31/big/img_286 +2003/01/13/big/img_294 +2002/07/31/big/img_251 +2002/07/24/big/img_13 +2002/08/31/big/img_17938 +2002/07/22/big/img_642 +2003/01/14/big/img_728 +2002/08/18/big/img_47 +2002/08/22/big/img_306 +2002/08/20/big/img_348 +2002/08/15/big/img_764 +2002/08/08/big/img_163 +2002/07/23/big/img_531 +2002/07/23/big/img_467 +2003/01/16/big/img_743 +2003/01/13/big/img_535 +2002/08/02/big/img_523 +2002/08/22/big/img_120 +2002/08/11/big/img_496 +2002/08/29/big/img_19075 +2002/08/08/big/img_465 +2002/08/09/big/img_790 +2002/08/19/big/img_588 +2002/08/23/big/img_407 +2003/01/17/big/img_435 +2002/08/24/big/img_398 +2002/08/27/big/img_19899 +2003/01/15/big/img_335 +2002/08/13/big/img_493 +2002/09/02/big/img_15460 +2002/07/31/big/img_470 +2002/08/05/big/img_3550 +2002/07/28/big/img_123 
+2002/08/01/big/img_1498 +2002/08/04/big/img_504 +2003/01/17/big/img_427 +2002/08/27/big/img_19708 +2002/07/27/big/img_861 +2002/07/25/big/img_685 +2002/07/31/big/img_207 +2003/01/14/big/img_745 +2002/08/31/big/img_17756 +2002/08/24/big/img_288 +2002/08/18/big/img_181 +2002/08/10/big/img_520 +2002/08/25/big/img_705 +2002/08/23/big/img_226 +2002/08/04/big/img_727 +2002/07/24/big/img_625 +2002/08/28/big/img_19157 +2002/08/23/big/img_586 +2002/07/31/big/img_232 +2003/01/13/big/img_240 +2003/01/14/big/img_321 +2003/01/15/big/img_533 +2002/07/23/big/img_480 +2002/07/24/big/img_371 +2002/08/21/big/img_702 +2002/08/31/big/img_17075 +2002/09/02/big/img_15278 +2002/07/29/big/img_246 +2003/01/15/big/img_829 +2003/01/15/big/img_1213 +2003/01/16/big/img_441 +2002/08/14/big/img_921 +2002/07/23/big/img_425 +2002/08/15/big/img_296 +2002/07/19/big/img_135 +2002/07/26/big/img_402 +2003/01/17/big/img_88 +2002/08/20/big/img_872 +2002/08/13/big/img_1110 +2003/01/16/big/img_1040 +2002/07/23/big/img_9 +2002/08/13/big/img_700 +2002/08/16/big/img_371 +2002/08/27/big/img_19966 +2003/01/17/big/img_391 +2002/08/18/big/img_426 +2002/08/01/big/img_1618 +2002/07/21/big/img_754 +2003/01/14/big/img_1101 +2003/01/16/big/img_1022 +2002/07/22/big/img_275 +2002/08/24/big/img_86 +2002/08/17/big/img_582 +2003/01/15/big/img_765 +2003/01/17/big/img_449 +2002/07/28/big/img_265 +2003/01/13/big/img_552 +2002/07/28/big/img_115 +2003/01/16/big/img_56 +2002/08/02/big/img_1232 +2003/01/17/big/img_925 +2002/07/22/big/img_445 +2002/07/25/big/img_957 +2002/07/20/big/img_589 +2002/08/31/big/img_17107 +2002/07/29/big/img_483 +2002/08/14/big/img_1063 +2002/08/07/big/img_1545 +2002/08/14/big/img_680 +2002/09/01/big/img_16694 +2002/08/14/big/img_257 +2002/08/11/big/img_726 +2002/07/26/big/img_681 +2002/07/25/big/img_481 +2003/01/14/big/img_737 +2002/08/28/big/img_19480 +2003/01/16/big/img_362 +2002/08/27/big/img_19865 +2003/01/01/big/img_547 +2002/09/02/big/img_15074 +2002/08/01/big/img_1453 +2002/08/22/big/img_594 
+2002/08/28/big/img_19263 +2002/08/13/big/img_478 +2002/07/29/big/img_1358 +2003/01/14/big/img_1022 +2002/08/16/big/img_450 +2002/08/02/big/img_159 +2002/07/26/big/img_781 +2003/01/13/big/img_601 +2002/08/20/big/img_407 +2002/08/15/big/img_468 +2002/08/31/big/img_17902 +2002/08/16/big/img_81 +2002/07/25/big/img_987 +2002/07/25/big/img_500 +2002/08/02/big/img_31 +2002/08/18/big/img_538 +2002/08/08/big/img_54 +2002/07/23/big/img_686 +2002/07/24/big/img_836 +2003/01/17/big/img_734 +2002/08/16/big/img_1055 +2003/01/16/big/img_521 +2002/07/25/big/img_612 +2002/08/22/big/img_778 +2002/08/03/big/img_251 +2002/08/12/big/img_436 +2002/08/23/big/img_705 +2002/07/28/big/img_243 +2002/07/25/big/img_1029 +2002/08/20/big/img_287 +2002/08/29/big/img_18739 +2002/08/05/big/img_3272 +2002/07/27/big/img_214 +2003/01/14/big/img_5 +2002/08/01/big/img_1380 +2002/08/29/big/img_19097 +2002/07/30/big/img_486 +2002/08/29/big/img_18707 +2002/08/10/big/img_559 +2002/08/15/big/img_365 +2002/08/09/big/img_525 +2002/08/10/big/img_689 +2002/07/25/big/img_502 +2002/08/03/big/img_667 +2002/08/10/big/img_855 +2002/08/10/big/img_706 +2002/08/18/big/img_603 +2003/01/16/big/img_1055 +2002/08/31/big/img_17890 +2002/08/15/big/img_761 +2003/01/15/big/img_489 +2002/08/26/big/img_351 +2002/08/01/big/img_1772 +2002/08/31/big/img_17729 +2002/07/25/big/img_609 +2003/01/13/big/img_539 +2002/07/27/big/img_686 +2002/07/31/big/img_311 +2002/08/22/big/img_799 +2003/01/16/big/img_936 +2002/08/31/big/img_17813 +2002/08/04/big/img_862 +2002/08/09/big/img_332 +2002/07/20/big/img_148 +2002/08/12/big/img_426 +2002/07/24/big/img_69 +2002/07/27/big/img_685 +2002/08/02/big/img_480 +2002/08/26/big/img_154 +2002/07/24/big/img_598 +2002/08/01/big/img_1881 +2002/08/20/big/img_667 +2003/01/14/big/img_495 +2002/07/21/big/img_744 +2002/07/30/big/img_150 +2002/07/23/big/img_924 +2002/08/08/big/img_272 +2002/07/23/big/img_310 +2002/07/25/big/img_1011 +2002/09/02/big/img_15725 +2002/07/19/big/img_814 +2002/08/20/big/img_936 
+2002/07/25/big/img_85 +2002/08/24/big/img_662 +2002/08/09/big/img_495 +2003/01/15/big/img_196 +2002/08/16/big/img_707 +2002/08/28/big/img_19370 +2002/08/06/big/img_2366 +2002/08/06/big/img_3012 +2002/08/01/big/img_1452 +2002/07/31/big/img_742 +2002/07/27/big/img_914 +2003/01/13/big/img_290 +2002/07/31/big/img_288 +2002/08/02/big/img_171 +2002/08/22/big/img_191 +2002/07/27/big/img_1066 +2002/08/12/big/img_383 +2003/01/17/big/img_1018 +2002/08/01/big/img_1785 +2002/08/11/big/img_390 +2002/08/27/big/img_20037 +2002/08/12/big/img_38 +2003/01/15/big/img_103 +2002/08/26/big/img_31 +2002/08/18/big/img_660 +2002/07/22/big/img_694 +2002/08/15/big/img_24 +2002/07/27/big/img_1077 +2002/08/01/big/img_1943 +2002/07/22/big/img_292 +2002/09/01/big/img_16857 +2002/07/22/big/img_892 +2003/01/14/big/img_46 +2002/08/09/big/img_469 +2002/08/09/big/img_414 +2003/01/16/big/img_40 +2002/08/28/big/img_19231 +2002/07/27/big/img_978 +2002/07/23/big/img_475 +2002/07/25/big/img_92 +2002/08/09/big/img_799 +2002/07/25/big/img_491 +2002/08/03/big/img_654 +2003/01/15/big/img_687 +2002/08/11/big/img_478 +2002/08/07/big/img_1664 +2002/08/20/big/img_362 +2002/08/01/big/img_1298 +2003/01/13/big/img_500 +2002/08/06/big/img_2896 +2002/08/30/big/img_18529 +2002/08/16/big/img_1020 +2002/07/29/big/img_892 +2002/08/29/big/img_18726 +2002/07/21/big/img_453 +2002/08/17/big/img_437 +2002/07/19/big/img_665 +2002/07/22/big/img_440 +2002/07/19/big/img_582 +2002/07/21/big/img_233 +2003/01/01/big/img_82 +2002/07/25/big/img_341 +2002/07/29/big/img_864 +2002/08/02/big/img_276 +2002/08/29/big/img_18654 +2002/07/27/big/img_1024 +2002/08/19/big/img_373 +2003/01/15/big/img_241 +2002/07/25/big/img_84 +2002/08/13/big/img_834 +2002/08/10/big/img_511 +2002/08/01/big/img_1627 +2002/08/08/big/img_607 +2002/08/06/big/img_2083 +2002/08/01/big/img_1486 +2002/08/08/big/img_700 +2002/08/01/big/img_1954 +2002/08/21/big/img_54 +2002/07/30/big/img_847 +2002/08/28/big/img_19169 +2002/07/21/big/img_549 +2002/08/03/big/img_693 
+2002/07/31/big/img_1002 +2003/01/14/big/img_1035 +2003/01/16/big/img_622 +2002/07/30/big/img_1201 +2002/08/10/big/img_444 +2002/07/31/big/img_374 +2002/08/21/big/img_301 +2002/08/13/big/img_1095 +2003/01/13/big/img_288 +2002/07/25/big/img_232 +2003/01/13/big/img_967 +2002/08/26/big/img_360 +2002/08/05/big/img_67 +2002/08/29/big/img_18969 +2002/07/28/big/img_16 +2002/08/16/big/img_515 +2002/07/20/big/img_708 +2002/08/18/big/img_178 +2003/01/15/big/img_509 +2002/07/25/big/img_430 +2002/08/21/big/img_738 +2002/08/16/big/img_886 +2002/09/02/big/img_15605 +2002/09/01/big/img_16242 +2002/08/24/big/img_711 +2002/07/25/big/img_90 +2002/08/09/big/img_491 +2002/07/30/big/img_534 +2003/01/13/big/img_474 +2002/08/25/big/img_510 +2002/08/15/big/img_555 +2002/08/02/big/img_775 +2002/07/23/big/img_975 +2002/08/19/big/img_229 +2003/01/17/big/img_860 +2003/01/02/big/img_10 +2002/07/23/big/img_542 +2002/08/06/big/img_2535 +2002/07/22/big/img_37 +2002/08/06/big/img_2342 +2002/08/25/big/img_515 +2002/08/25/big/img_336 +2002/08/18/big/img_837 +2002/08/21/big/img_616 +2003/01/17/big/img_24 +2002/07/26/big/img_936 +2002/08/14/big/img_896 +2002/07/29/big/img_465 +2002/07/31/big/img_543 +2002/08/01/big/img_1411 +2002/08/02/big/img_423 +2002/08/21/big/img_44 +2002/07/31/big/img_11 +2003/01/15/big/img_628 +2003/01/15/big/img_605 +2002/07/30/big/img_571 +2002/07/23/big/img_428 +2002/08/15/big/img_942 +2002/07/26/big/img_531 +2003/01/16/big/img_59 +2002/08/02/big/img_410 +2002/07/31/big/img_230 +2002/08/19/big/img_806 +2003/01/14/big/img_462 +2002/08/16/big/img_370 +2002/08/13/big/img_380 +2002/08/16/big/img_932 +2002/07/19/big/img_393 +2002/08/20/big/img_764 +2002/08/15/big/img_616 +2002/07/26/big/img_267 +2002/07/27/big/img_1069 +2002/08/14/big/img_1041 +2003/01/13/big/img_594 +2002/09/01/big/img_16845 +2002/08/09/big/img_229 +2003/01/16/big/img_639 +2002/08/19/big/img_398 +2002/08/18/big/img_978 +2002/08/24/big/img_296 +2002/07/29/big/img_415 +2002/07/30/big/img_923 +2002/08/18/big/img_575 
+2002/08/22/big/img_182 +2002/07/25/big/img_806 +2002/07/22/big/img_49 +2002/07/29/big/img_989 +2003/01/17/big/img_789 +2003/01/15/big/img_503 +2002/09/01/big/img_16062 +2003/01/17/big/img_794 +2002/08/15/big/img_564 +2003/01/15/big/img_222 +2002/08/01/big/img_1656 +2003/01/13/big/img_432 +2002/07/19/big/img_426 +2002/08/17/big/img_244 +2002/08/13/big/img_805 +2002/09/02/big/img_15067 +2002/08/11/big/img_58 +2002/08/22/big/img_636 +2002/07/22/big/img_416 +2002/08/13/big/img_836 +2002/08/26/big/img_363 +2002/07/30/big/img_917 +2003/01/14/big/img_206 +2002/08/12/big/img_311 +2002/08/31/big/img_17623 +2002/07/29/big/img_661 +2003/01/13/big/img_417 +2002/08/02/big/img_463 +2002/08/02/big/img_669 +2002/08/26/big/img_670 +2002/08/02/big/img_375 +2002/07/19/big/img_209 +2002/08/08/big/img_115 +2002/08/21/big/img_399 +2002/08/20/big/img_911 +2002/08/07/big/img_1212 +2002/08/20/big/img_578 +2002/08/22/big/img_554 +2002/08/21/big/img_484 +2002/07/25/big/img_450 +2002/08/03/big/img_542 +2002/08/15/big/img_561 +2002/07/23/big/img_360 +2002/08/30/big/img_18137 +2002/07/25/big/img_250 +2002/08/03/big/img_647 +2002/08/20/big/img_375 +2002/08/14/big/img_387 +2002/09/01/big/img_16990 +2002/08/28/big/img_19341 +2003/01/15/big/img_239 +2002/08/20/big/img_528 +2002/08/12/big/img_130 +2002/09/02/big/img_15108 +2003/01/15/big/img_372 +2002/08/16/big/img_678 +2002/08/04/big/img_623 +2002/07/23/big/img_477 +2002/08/28/big/img_19590 +2003/01/17/big/img_978 +2002/09/01/big/img_16692 +2002/07/20/big/img_109 +2002/08/06/big/img_2660 +2003/01/14/big/img_464 +2002/08/09/big/img_618 +2002/07/22/big/img_722 +2002/08/25/big/img_419 +2002/08/03/big/img_314 +2002/08/25/big/img_40 +2002/07/27/big/img_430 +2002/08/10/big/img_569 +2002/08/23/big/img_398 +2002/07/23/big/img_893 +2002/08/16/big/img_261 +2002/08/06/big/img_2668 +2002/07/22/big/img_835 +2002/09/02/big/img_15093 +2003/01/16/big/img_65 +2002/08/21/big/img_448 +2003/01/14/big/img_351 +2003/01/17/big/img_133 +2002/07/28/big/img_493 
+2003/01/15/big/img_640 +2002/09/01/big/img_16880 +2002/08/15/big/img_350 +2002/08/20/big/img_624 +2002/08/25/big/img_604 +2002/08/06/big/img_2200 +2002/08/23/big/img_290 +2002/08/13/big/img_1152 +2003/01/14/big/img_251 +2002/08/02/big/img_538 +2002/08/22/big/img_613 +2003/01/13/big/img_351 +2002/08/18/big/img_368 +2002/07/23/big/img_392 +2002/07/25/big/img_198 +2002/07/25/big/img_418 +2002/08/26/big/img_614 +2002/07/23/big/img_405 +2003/01/14/big/img_445 +2002/07/25/big/img_326 +2002/08/10/big/img_734 +2003/01/14/big/img_530 +2002/08/08/big/img_561 +2002/08/29/big/img_18990 +2002/08/10/big/img_576 +2002/07/29/big/img_1494 +2002/07/19/big/img_198 +2002/08/10/big/img_562 +2002/07/22/big/img_901 +2003/01/14/big/img_37 +2002/09/02/big/img_15629 +2003/01/14/big/img_58 +2002/08/01/big/img_1364 +2002/07/27/big/img_636 +2003/01/13/big/img_241 +2002/09/01/big/img_16988 +2003/01/13/big/img_560 +2002/08/09/big/img_533 +2002/07/31/big/img_249 +2003/01/17/big/img_1007 +2002/07/21/big/img_64 +2003/01/13/big/img_537 +2003/01/15/big/img_606 +2002/08/18/big/img_651 +2002/08/24/big/img_405 +2002/07/26/big/img_837 +2002/08/09/big/img_562 +2002/08/01/big/img_1983 +2002/08/03/big/img_514 +2002/07/29/big/img_314 +2002/08/12/big/img_493 +2003/01/14/big/img_121 +2003/01/14/big/img_479 +2002/08/04/big/img_410 +2002/07/22/big/img_607 +2003/01/17/big/img_417 +2002/07/20/big/img_547 +2002/08/13/big/img_396 +2002/08/31/big/img_17538 +2002/08/13/big/img_187 +2002/08/12/big/img_328 +2003/01/14/big/img_569 +2002/07/27/big/img_1081 +2002/08/14/big/img_504 +2002/08/23/big/img_785 +2002/07/26/big/img_339 +2002/08/07/big/img_1156 +2002/08/07/big/img_1456 +2002/08/23/big/img_378 +2002/08/27/big/img_19719 +2002/07/31/big/img_39 +2002/07/31/big/img_883 +2003/01/14/big/img_676 +2002/07/29/big/img_214 +2002/07/26/big/img_669 +2002/07/25/big/img_202 +2002/08/08/big/img_259 +2003/01/17/big/img_943 +2003/01/15/big/img_512 +2002/08/05/big/img_3295 +2002/08/27/big/img_19685 +2002/08/08/big/img_277 
+2002/08/30/big/img_18154 +2002/07/22/big/img_663 +2002/08/29/big/img_18914 +2002/07/31/big/img_908 +2002/08/27/big/img_19926 +2003/01/13/big/img_791 +2003/01/15/big/img_827 +2002/08/18/big/img_878 +2002/08/14/big/img_670 +2002/07/20/big/img_182 +2002/08/15/big/img_291 +2002/08/06/big/img_2600 +2002/07/23/big/img_587 +2002/08/14/big/img_577 +2003/01/15/big/img_585 +2002/07/30/big/img_310 +2002/08/03/big/img_658 +2002/08/10/big/img_157 +2002/08/19/big/img_811 +2002/07/29/big/img_1318 +2002/08/04/big/img_104 +2002/07/30/big/img_332 +2002/07/24/big/img_789 +2002/07/29/big/img_516 +2002/07/23/big/img_843 +2002/08/01/big/img_1528 +2002/08/13/big/img_798 +2002/08/07/big/img_1729 +2002/08/28/big/img_19448 +2003/01/16/big/img_95 +2002/08/12/big/img_473 +2002/07/27/big/img_269 +2003/01/16/big/img_621 +2002/07/29/big/img_772 +2002/07/24/big/img_171 +2002/07/19/big/img_429 +2002/08/07/big/img_1933 +2002/08/27/big/img_19629 +2002/08/05/big/img_3688 +2002/08/07/big/img_1691 +2002/07/23/big/img_600 +2002/07/29/big/img_666 +2002/08/25/big/img_566 +2002/08/06/big/img_2659 +2002/08/29/big/img_18929 +2002/08/16/big/img_407 +2002/08/18/big/img_774 +2002/08/19/big/img_249 +2002/08/06/big/img_2427 +2002/08/29/big/img_18899 +2002/08/01/big/img_1818 +2002/07/31/big/img_108 +2002/07/29/big/img_500 +2002/08/11/big/img_115 +2002/07/19/big/img_521 +2002/08/02/big/img_1163 +2002/07/22/big/img_62 +2002/08/13/big/img_466 +2002/08/21/big/img_956 +2002/08/23/big/img_602 +2002/08/20/big/img_858 +2002/07/25/big/img_690 +2002/07/19/big/img_130 +2002/08/04/big/img_874 +2002/07/26/big/img_489 +2002/07/22/big/img_548 +2002/08/10/big/img_191 +2002/07/25/big/img_1051 +2002/08/18/big/img_473 +2002/08/12/big/img_755 +2002/08/18/big/img_413 +2002/08/08/big/img_1044 +2002/08/17/big/img_680 +2002/08/26/big/img_235 +2002/08/20/big/img_330 +2002/08/22/big/img_344 +2002/08/09/big/img_593 +2002/07/31/big/img_1006 +2002/08/14/big/img_337 +2002/08/16/big/img_728 +2002/07/24/big/img_834 +2002/08/04/big/img_552 
+2002/09/02/big/img_15213 +2002/07/25/big/img_725 +2002/08/30/big/img_18290 +2003/01/01/big/img_475 +2002/07/27/big/img_1083 +2002/08/29/big/img_18955 +2002/08/31/big/img_17232 +2002/08/08/big/img_480 +2002/08/01/big/img_1311 +2002/07/30/big/img_745 +2002/08/03/big/img_649 +2002/08/12/big/img_193 +2002/07/29/big/img_228 +2002/07/25/big/img_836 +2002/08/20/big/img_400 +2002/07/30/big/img_507 +2002/09/02/big/img_15072 +2002/07/26/big/img_658 +2002/07/28/big/img_503 +2002/08/05/big/img_3814 +2002/08/24/big/img_745 +2003/01/13/big/img_817 +2002/08/08/big/img_579 +2002/07/22/big/img_251 +2003/01/13/big/img_689 +2002/07/25/big/img_407 +2002/08/13/big/img_1050 +2002/08/14/big/img_733 +2002/07/24/big/img_82 +2003/01/17/big/img_288 +2003/01/15/big/img_475 +2002/08/14/big/img_620 +2002/08/21/big/img_167 +2002/07/19/big/img_300 +2002/07/26/big/img_219 +2002/08/01/big/img_1468 +2002/07/23/big/img_260 +2002/08/09/big/img_555 +2002/07/19/big/img_160 +2002/08/02/big/img_1060 +2003/01/14/big/img_149 +2002/08/15/big/img_346 +2002/08/24/big/img_597 +2002/08/22/big/img_502 +2002/08/30/big/img_18228 +2002/07/21/big/img_766 +2003/01/15/big/img_841 +2002/07/24/big/img_516 +2002/08/02/big/img_265 +2002/08/15/big/img_1243 +2003/01/15/big/img_223 +2002/08/04/big/img_236 +2002/07/22/big/img_309 +2002/07/20/big/img_656 +2002/07/31/big/img_412 +2002/09/01/big/img_16462 +2003/01/16/big/img_431 +2002/07/22/big/img_793 +2002/08/15/big/img_877 +2002/07/26/big/img_282 +2002/07/25/big/img_529 +2002/08/24/big/img_613 +2003/01/17/big/img_700 +2002/08/06/big/img_2526 +2002/08/24/big/img_394 +2002/08/21/big/img_521 +2002/08/25/big/img_560 +2002/07/29/big/img_966 +2002/07/25/big/img_448 +2003/01/13/big/img_782 +2002/08/21/big/img_296 +2002/09/01/big/img_16755 +2002/08/05/big/img_3552 +2002/09/02/big/img_15823 +2003/01/14/big/img_193 +2002/07/21/big/img_159 +2002/08/02/big/img_564 +2002/08/16/big/img_300 +2002/07/19/big/img_269 +2002/08/13/big/img_676 +2002/07/28/big/img_57 +2002/08/05/big/img_3318 
+2002/07/31/big/img_218 +2002/08/21/big/img_898 +2002/07/29/big/img_109 +2002/07/19/big/img_854 +2002/08/23/big/img_311 +2002/08/14/big/img_318 +2002/07/25/big/img_523 +2002/07/21/big/img_678 +2003/01/17/big/img_690 +2002/08/28/big/img_19503 +2002/08/18/big/img_251 +2002/08/22/big/img_672 +2002/08/20/big/img_663 +2002/08/02/big/img_148 +2002/09/02/big/img_15580 +2002/07/25/big/img_778 +2002/08/14/big/img_565 +2002/08/12/big/img_374 +2002/08/13/big/img_1018 +2002/08/20/big/img_474 +2002/08/25/big/img_33 +2002/08/02/big/img_1190 +2002/08/08/big/img_864 +2002/08/14/big/img_1071 +2002/08/30/big/img_18103 +2002/08/18/big/img_533 +2003/01/16/big/img_650 +2002/07/25/big/img_108 +2002/07/26/big/img_81 +2002/07/27/big/img_543 +2002/07/29/big/img_521 +2003/01/13/big/img_434 +2002/08/26/big/img_674 +2002/08/06/big/img_2932 +2002/08/07/big/img_1262 +2003/01/15/big/img_201 +2003/01/16/big/img_673 +2002/09/02/big/img_15988 +2002/07/29/big/img_1306 +2003/01/14/big/img_1072 +2002/08/30/big/img_18232 +2002/08/05/big/img_3711 +2002/07/23/big/img_775 +2002/08/01/big/img_16 +2003/01/16/big/img_630 +2002/08/22/big/img_695 +2002/08/14/big/img_51 +2002/08/14/big/img_782 +2002/08/24/big/img_742 +2003/01/14/big/img_512 +2003/01/15/big/img_1183 +2003/01/15/big/img_714 +2002/08/01/big/img_2078 +2002/07/31/big/img_682 +2002/09/02/big/img_15687 +2002/07/26/big/img_518 +2002/08/27/big/img_19676 +2002/09/02/big/img_15969 +2002/08/02/big/img_931 +2002/08/25/big/img_508 +2002/08/29/big/img_18616 +2002/07/22/big/img_839 +2002/07/28/big/img_313 +2003/01/14/big/img_155 +2002/08/02/big/img_1105 +2002/08/09/big/img_53 +2002/08/16/big/img_469 +2002/08/15/big/img_502 +2002/08/20/big/img_575 +2002/07/25/big/img_138 +2003/01/16/big/img_579 +2002/07/19/big/img_352 +2003/01/14/big/img_762 +2003/01/01/big/img_588 +2002/08/02/big/img_981 +2002/08/21/big/img_447 +2002/09/01/big/img_16151 +2003/01/14/big/img_769 +2002/08/23/big/img_461 +2002/08/17/big/img_240 +2002/09/02/big/img_15220 +2002/07/19/big/img_408 
+2002/09/02/big/img_15496 +2002/07/29/big/img_758 +2002/08/28/big/img_19392 +2002/08/06/big/img_2723 +2002/08/31/big/img_17752 +2002/08/23/big/img_469 +2002/08/13/big/img_515 +2002/09/02/big/img_15551 +2002/08/03/big/img_462 +2002/07/24/big/img_613 +2002/07/22/big/img_61 +2002/08/08/big/img_171 +2002/08/21/big/img_177 +2003/01/14/big/img_105 +2002/08/02/big/img_1017 +2002/08/22/big/img_106 +2002/07/27/big/img_542 +2002/07/21/big/img_665 +2002/07/23/big/img_595 +2002/08/04/big/img_657 +2002/08/29/big/img_19002 +2003/01/15/big/img_550 +2002/08/14/big/img_662 +2002/07/20/big/img_425 +2002/08/30/big/img_18528 +2002/07/26/big/img_611 +2002/07/22/big/img_849 +2002/08/07/big/img_1655 +2002/08/21/big/img_638 +2003/01/17/big/img_732 +2003/01/01/big/img_496 +2002/08/18/big/img_713 +2002/08/08/big/img_109 +2002/07/27/big/img_1008 +2002/07/20/big/img_559 +2002/08/16/big/img_699 +2002/08/31/big/img_17702 +2002/07/31/big/img_1013 +2002/08/01/big/img_2027 +2002/08/02/big/img_1001 +2002/08/03/big/img_210 +2002/08/01/big/img_2087 +2003/01/14/big/img_199 +2002/07/29/big/img_48 +2002/07/19/big/img_727 +2002/08/09/big/img_249 +2002/08/04/big/img_632 +2002/08/22/big/img_620 +2003/01/01/big/img_457 +2002/08/05/big/img_3223 +2002/07/27/big/img_240 +2002/07/25/big/img_797 +2002/08/13/big/img_430 +2002/07/25/big/img_615 +2002/08/12/big/img_28 +2002/07/30/big/img_220 +2002/07/24/big/img_89 +2002/08/21/big/img_357 +2002/08/09/big/img_590 +2003/01/13/big/img_525 +2002/08/17/big/img_818 +2003/01/02/big/img_7 +2002/07/26/big/img_636 +2003/01/13/big/img_1122 +2002/07/23/big/img_810 +2002/08/20/big/img_888 +2002/07/27/big/img_3 +2002/08/15/big/img_451 +2002/09/02/big/img_15787 +2002/07/31/big/img_281 +2002/08/05/big/img_3274 +2002/08/07/big/img_1254 +2002/07/31/big/img_27 +2002/08/01/big/img_1366 +2002/07/30/big/img_182 +2002/08/27/big/img_19690 +2002/07/29/big/img_68 +2002/08/23/big/img_754 +2002/07/30/big/img_540 +2002/08/27/big/img_20063 +2002/08/14/big/img_471 +2002/08/02/big/img_615 
+2002/07/30/big/img_186 +2002/08/25/big/img_150 +2002/07/27/big/img_626 +2002/07/20/big/img_225 +2003/01/15/big/img_1252 +2002/07/19/big/img_367 +2003/01/15/big/img_582 +2002/08/09/big/img_572 +2002/08/08/big/img_428 +2003/01/15/big/img_639 +2002/08/28/big/img_19245 +2002/07/24/big/img_321 +2002/08/02/big/img_662 +2002/08/08/big/img_1033 +2003/01/17/big/img_867 +2002/07/22/big/img_652 +2003/01/14/big/img_224 +2002/08/18/big/img_49 +2002/07/26/big/img_46 +2002/08/31/big/img_18021 +2002/07/25/big/img_151 +2002/08/23/big/img_540 +2002/08/25/big/img_693 +2002/07/23/big/img_340 +2002/07/28/big/img_117 +2002/09/02/big/img_15768 +2002/08/26/big/img_562 +2002/07/24/big/img_480 +2003/01/15/big/img_341 +2002/08/10/big/img_783 +2002/08/20/big/img_132 +2003/01/14/big/img_370 +2002/07/20/big/img_720 +2002/08/03/big/img_144 +2002/08/20/big/img_538 +2002/08/01/big/img_1745 +2002/08/11/big/img_683 +2002/08/03/big/img_328 +2002/08/10/big/img_793 +2002/08/14/big/img_689 +2002/08/02/big/img_162 +2003/01/17/big/img_411 +2002/07/31/big/img_361 +2002/08/15/big/img_289 +2002/08/08/big/img_254 +2002/08/15/big/img_996 +2002/08/20/big/img_785 +2002/07/24/big/img_511 +2002/08/06/big/img_2614 +2002/08/29/big/img_18733 +2002/08/17/big/img_78 +2002/07/30/big/img_378 +2002/08/31/big/img_17947 +2002/08/26/big/img_88 +2002/07/30/big/img_558 +2002/08/02/big/img_67 +2003/01/14/big/img_325 +2002/07/29/big/img_1357 +2002/07/19/big/img_391 +2002/07/30/big/img_307 +2003/01/13/big/img_219 +2002/07/24/big/img_807 +2002/08/23/big/img_543 +2002/08/29/big/img_18620 +2002/07/22/big/img_769 +2002/08/26/big/img_503 +2002/07/30/big/img_78 +2002/08/14/big/img_1036 +2002/08/09/big/img_58 +2002/07/24/big/img_616 +2002/08/02/big/img_464 +2002/07/26/big/img_576 +2002/07/22/big/img_273 +2003/01/16/big/img_470 +2002/07/29/big/img_329 +2002/07/30/big/img_1086 +2002/07/31/big/img_353 +2002/09/02/big/img_15275 +2003/01/17/big/img_555 +2002/08/26/big/img_212 +2002/08/01/big/img_1692 +2003/01/15/big/img_600 
+2002/07/29/big/img_825 +2002/08/08/big/img_68 +2002/08/10/big/img_719 +2002/07/31/big/img_636 +2002/07/29/big/img_325 +2002/07/21/big/img_515 +2002/07/22/big/img_705 +2003/01/13/big/img_818 +2002/08/09/big/img_486 +2002/08/22/big/img_141 +2002/07/22/big/img_303 +2002/08/09/big/img_393 +2002/07/29/big/img_963 +2002/08/02/big/img_1215 +2002/08/19/big/img_674 +2002/08/12/big/img_690 +2002/08/21/big/img_637 +2002/08/21/big/img_841 +2002/08/24/big/img_71 +2002/07/25/big/img_596 +2002/07/24/big/img_864 +2002/08/18/big/img_293 +2003/01/14/big/img_657 +2002/08/15/big/img_411 +2002/08/16/big/img_348 +2002/08/05/big/img_3157 +2002/07/20/big/img_663 +2003/01/13/big/img_654 +2003/01/16/big/img_433 +2002/08/30/big/img_18200 +2002/08/12/big/img_226 +2003/01/16/big/img_491 +2002/08/08/big/img_666 +2002/07/19/big/img_576 +2003/01/15/big/img_776 +2003/01/16/big/img_899 +2002/07/19/big/img_397 +2002/08/14/big/img_44 +2003/01/15/big/img_762 +2002/08/02/big/img_982 +2002/09/02/big/img_15234 +2002/08/17/big/img_556 +2002/08/21/big/img_410 +2002/08/21/big/img_386 +2002/07/19/big/img_690 +2002/08/05/big/img_3052 +2002/08/14/big/img_219 +2002/08/16/big/img_273 +2003/01/15/big/img_752 +2002/08/08/big/img_184 +2002/07/31/big/img_743 +2002/08/23/big/img_338 +2003/01/14/big/img_1055 +2002/08/05/big/img_3405 +2003/01/15/big/img_17 +2002/08/03/big/img_141 +2002/08/14/big/img_549 +2002/07/27/big/img_1034 +2002/07/31/big/img_932 +2002/08/30/big/img_18487 +2002/09/02/big/img_15814 +2002/08/01/big/img_2086 +2002/09/01/big/img_16535 +2002/07/22/big/img_500 +2003/01/13/big/img_400 +2002/08/25/big/img_607 +2002/08/30/big/img_18384 +2003/01/14/big/img_951 +2002/08/13/big/img_1150 +2002/08/08/big/img_1022 +2002/08/10/big/img_428 +2002/08/28/big/img_19242 +2002/08/05/big/img_3098 +2002/07/23/big/img_400 +2002/08/26/big/img_365 +2002/07/20/big/img_318 +2002/08/13/big/img_740 +2003/01/16/big/img_37 +2002/08/26/big/img_274 +2002/08/02/big/img_205 +2002/08/21/big/img_695 +2002/08/06/big/img_2289 
+2002/08/20/big/img_794 +2002/08/18/big/img_438 +2002/08/07/big/img_1380 +2002/08/02/big/img_737 +2002/08/07/big/img_1651 +2002/08/15/big/img_1238 +2002/08/01/big/img_1681 +2002/08/06/big/img_3017 +2002/07/23/big/img_706 +2002/07/31/big/img_392 +2002/08/09/big/img_539 +2002/07/29/big/img_835 +2002/08/26/big/img_723 +2002/08/28/big/img_19235 +2003/01/16/big/img_353 +2002/08/10/big/img_150 +2002/08/29/big/img_19025 +2002/08/21/big/img_310 +2002/08/10/big/img_823 +2002/07/26/big/img_981 +2002/08/11/big/img_288 +2002/08/19/big/img_534 +2002/08/21/big/img_300 +2002/07/31/big/img_49 +2002/07/30/big/img_469 +2002/08/28/big/img_19197 +2002/08/25/big/img_205 +2002/08/10/big/img_390 +2002/08/23/big/img_291 +2002/08/26/big/img_230 +2002/08/18/big/img_76 +2002/07/23/big/img_409 +2002/08/14/big/img_1053 +2003/01/14/big/img_291 +2002/08/10/big/img_503 +2002/08/27/big/img_19928 +2002/08/03/big/img_563 +2002/08/17/big/img_250 +2002/08/06/big/img_2381 +2002/08/17/big/img_948 +2002/08/06/big/img_2710 +2002/07/22/big/img_696 +2002/07/31/big/img_670 +2002/08/12/big/img_594 +2002/07/29/big/img_624 +2003/01/17/big/img_934 +2002/08/03/big/img_584 +2002/08/22/big/img_1003 +2002/08/05/big/img_3396 +2003/01/13/big/img_570 +2002/08/02/big/img_219 +2002/09/02/big/img_15774 +2002/08/16/big/img_818 +2002/08/23/big/img_402 +2003/01/14/big/img_552 +2002/07/29/big/img_71 +2002/08/05/big/img_3592 +2002/08/16/big/img_80 +2002/07/27/big/img_672 +2003/01/13/big/img_470 +2003/01/16/big/img_702 +2002/09/01/big/img_16130 +2002/08/08/big/img_240 +2002/09/01/big/img_16338 +2002/07/26/big/img_312 +2003/01/14/big/img_538 +2002/07/20/big/img_695 +2002/08/30/big/img_18098 +2002/08/25/big/img_259 +2002/08/16/big/img_1042 +2002/08/09/big/img_837 +2002/08/31/big/img_17760 +2002/07/31/big/img_14 +2002/08/09/big/img_361 +2003/01/16/big/img_107 +2002/08/14/big/img_124 +2002/07/19/big/img_463 +2003/01/15/big/img_275 +2002/07/25/big/img_1151 +2002/07/29/big/img_1501 +2002/08/27/big/img_19889 +2002/08/29/big/img_18603 
+2003/01/17/big/img_601 +2002/08/25/big/img_355 +2002/08/08/big/img_297 +2002/08/20/big/img_290 +2002/07/31/big/img_195 +2003/01/01/big/img_336 +2002/08/18/big/img_369 +2002/07/25/big/img_621 +2002/08/11/big/img_508 +2003/01/14/big/img_458 +2003/01/15/big/img_795 +2002/08/12/big/img_498 +2002/08/01/big/img_1734 +2002/08/02/big/img_246 +2002/08/16/big/img_565 +2002/08/11/big/img_475 +2002/08/22/big/img_408 +2002/07/28/big/img_78 +2002/07/21/big/img_81 +2003/01/14/big/img_697 +2002/08/14/big/img_661 +2002/08/15/big/img_507 +2002/08/19/big/img_55 +2002/07/22/big/img_152 +2003/01/14/big/img_470 +2002/08/03/big/img_379 +2002/08/22/big/img_506 +2003/01/16/big/img_966 +2002/08/18/big/img_698 +2002/08/24/big/img_528 +2002/08/23/big/img_10 +2002/08/01/big/img_1655 +2002/08/22/big/img_953 +2002/07/19/big/img_630 +2002/07/22/big/img_889 +2002/08/16/big/img_351 +2003/01/16/big/img_83 +2002/07/19/big/img_805 +2002/08/14/big/img_704 +2002/07/19/big/img_389 +2002/08/31/big/img_17765 +2002/07/29/big/img_606 +2003/01/17/big/img_939 +2002/09/02/big/img_15081 +2002/08/21/big/img_181 +2002/07/29/big/img_1321 +2002/07/21/big/img_497 +2002/07/20/big/img_539 +2002/08/24/big/img_119 +2002/08/01/big/img_1281 +2002/07/26/big/img_207 +2002/07/26/big/img_432 +2002/07/27/big/img_1006 +2002/08/05/big/img_3087 +2002/08/14/big/img_252 +2002/08/14/big/img_798 +2002/07/24/big/img_538 +2002/09/02/big/img_15507 +2002/08/08/big/img_901 +2003/01/14/big/img_557 +2002/08/07/big/img_1819 +2002/08/04/big/img_470 +2002/08/01/big/img_1504 +2002/08/16/big/img_1070 +2002/08/16/big/img_372 +2002/08/23/big/img_416 +2002/08/30/big/img_18208 +2002/08/01/big/img_2043 +2002/07/22/big/img_385 +2002/08/22/big/img_466 +2002/08/21/big/img_869 +2002/08/28/big/img_19429 +2002/08/02/big/img_770 +2002/07/23/big/img_433 +2003/01/14/big/img_13 +2002/07/27/big/img_953 +2002/09/02/big/img_15728 +2002/08/01/big/img_1361 +2002/08/29/big/img_18897 +2002/08/26/big/img_534 +2002/08/11/big/img_121 +2002/08/26/big/img_20130 
+2002/07/31/big/img_363 +2002/08/13/big/img_978 +2002/07/25/big/img_835 +2002/08/02/big/img_906 +2003/01/14/big/img_548 +2002/07/30/big/img_80 +2002/07/26/big/img_982 +2003/01/16/big/img_99 +2002/08/19/big/img_362 +2002/08/24/big/img_376 +2002/08/07/big/img_1264 +2002/07/27/big/img_938 +2003/01/17/big/img_535 +2002/07/26/big/img_457 +2002/08/08/big/img_848 +2003/01/15/big/img_859 +2003/01/15/big/img_622 +2002/07/30/big/img_403 +2002/07/29/big/img_217 +2002/07/26/big/img_891 +2002/07/24/big/img_70 +2002/08/25/big/img_619 +2002/08/05/big/img_3375 +2002/08/01/big/img_2160 +2002/08/06/big/img_2227 +2003/01/14/big/img_117 +2002/08/14/big/img_227 +2002/08/13/big/img_565 +2002/08/19/big/img_625 +2002/08/03/big/img_812 +2002/07/24/big/img_41 +2002/08/16/big/img_235 +2002/07/29/big/img_759 +2002/07/21/big/img_433 +2002/07/29/big/img_190 +2003/01/16/big/img_435 +2003/01/13/big/img_708 +2002/07/30/big/img_57 +2002/08/22/big/img_162 +2003/01/01/big/img_558 +2003/01/15/big/img_604 +2002/08/16/big/img_935 +2002/08/20/big/img_394 +2002/07/28/big/img_465 +2002/09/02/big/img_15534 +2002/08/16/big/img_87 +2002/07/22/big/img_469 +2002/08/12/big/img_245 +2003/01/13/big/img_236 +2002/08/06/big/img_2736 +2002/08/03/big/img_348 +2003/01/14/big/img_218 +2002/07/26/big/img_232 +2003/01/15/big/img_244 +2002/07/25/big/img_1121 +2002/08/01/big/img_1484 +2002/07/26/big/img_541 +2002/08/07/big/img_1244 +2002/07/31/big/img_3 +2002/08/30/big/img_18437 +2002/08/29/big/img_19094 +2002/08/01/big/img_1355 +2002/08/19/big/img_338 +2002/07/19/big/img_255 +2002/07/21/big/img_76 +2002/08/25/big/img_199 +2002/08/12/big/img_740 +2002/07/30/big/img_852 +2002/08/15/big/img_599 +2002/08/23/big/img_254 +2002/08/19/big/img_125 +2002/07/24/big/img_2 +2002/08/04/big/img_145 +2002/08/05/big/img_3137 +2002/07/28/big/img_463 +2003/01/14/big/img_801 +2002/07/23/big/img_366 +2002/08/26/big/img_600 +2002/08/26/big/img_649 +2002/09/02/big/img_15849 +2002/07/26/big/img_248 +2003/01/13/big/img_200 +2002/08/07/big/img_1794 
+2002/08/31/big/img_17270 +2002/08/23/big/img_608 +2003/01/13/big/img_837 +2002/08/23/big/img_581 +2002/08/20/big/img_754 +2002/08/18/big/img_183 +2002/08/20/big/img_328 +2002/07/22/big/img_494 +2002/07/29/big/img_399 +2002/08/28/big/img_19284 +2002/08/08/big/img_566 +2002/07/25/big/img_376 +2002/07/23/big/img_138 +2002/07/25/big/img_435 +2002/08/17/big/img_685 +2002/07/19/big/img_90 +2002/07/20/big/img_716 +2002/08/31/big/img_17458 +2002/08/26/big/img_461 +2002/07/25/big/img_355 +2002/08/06/big/img_2152 +2002/07/27/big/img_932 +2002/07/23/big/img_232 +2002/08/08/big/img_1020 +2002/07/31/big/img_366 +2002/08/06/big/img_2667 +2002/08/21/big/img_465 +2002/08/15/big/img_305 +2002/08/02/big/img_247 +2002/07/28/big/img_46 +2002/08/27/big/img_19922 +2002/08/23/big/img_643 +2003/01/13/big/img_624 +2002/08/23/big/img_625 +2002/08/05/big/img_3787 +2003/01/13/big/img_627 +2002/09/01/big/img_16381 +2002/08/05/big/img_3668 +2002/07/21/big/img_535 +2002/08/27/big/img_19680 +2002/07/22/big/img_413 +2002/07/29/big/img_481 +2003/01/15/big/img_496 +2002/07/23/big/img_701 +2002/08/29/big/img_18670 +2002/07/28/big/img_319 +2003/01/14/big/img_517 +2002/07/26/big/img_256 +2003/01/16/big/img_593 +2002/07/30/big/img_956 +2002/07/30/big/img_667 +2002/07/25/big/img_100 +2002/08/11/big/img_570 +2002/07/26/big/img_745 +2002/08/04/big/img_834 +2002/08/25/big/img_521 +2002/08/01/big/img_2148 +2002/09/02/big/img_15183 +2002/08/22/big/img_514 +2002/08/23/big/img_477 +2002/07/23/big/img_336 +2002/07/26/big/img_481 +2002/08/20/big/img_409 +2002/07/23/big/img_918 +2002/08/09/big/img_474 +2002/08/02/big/img_929 +2002/08/31/big/img_17932 +2002/08/19/big/img_161 +2002/08/09/big/img_667 +2002/07/31/big/img_805 +2002/09/02/big/img_15678 +2002/08/31/big/img_17509 +2002/08/29/big/img_18998 +2002/07/23/big/img_301 +2002/08/07/big/img_1612 +2002/08/06/big/img_2472 +2002/07/23/big/img_466 +2002/08/27/big/img_19634 +2003/01/16/big/img_16 +2002/08/14/big/img_193 +2002/08/21/big/img_340 
+2002/08/27/big/img_19799 +2002/08/01/big/img_1345 +2002/08/07/big/img_1448 +2002/08/11/big/img_324 +2003/01/16/big/img_754 +2002/08/13/big/img_418 +2003/01/16/big/img_544 +2002/08/19/big/img_135 +2002/08/10/big/img_455 +2002/08/10/big/img_693 +2002/08/31/big/img_17967 +2002/08/28/big/img_19229 +2002/08/04/big/img_811 +2002/09/01/big/img_16225 +2003/01/16/big/img_428 +2002/09/02/big/img_15295 +2002/07/26/big/img_108 +2002/07/21/big/img_477 +2002/08/07/big/img_1354 +2002/08/23/big/img_246 +2002/08/16/big/img_652 +2002/07/27/big/img_553 +2002/07/31/big/img_346 +2002/08/04/big/img_537 +2002/08/08/big/img_498 +2002/08/29/big/img_18956 +2003/01/13/big/img_922 +2002/08/31/big/img_17425 +2002/07/26/big/img_438 +2002/08/19/big/img_185 +2003/01/16/big/img_33 +2002/08/10/big/img_252 +2002/07/29/big/img_598 +2002/08/27/big/img_19820 +2002/08/06/big/img_2664 +2002/08/20/big/img_705 +2003/01/14/big/img_816 +2002/08/03/big/img_552 +2002/07/25/big/img_561 +2002/07/25/big/img_934 +2002/08/01/big/img_1893 +2003/01/14/big/img_746 +2003/01/16/big/img_519 +2002/08/03/big/img_681 +2002/07/24/big/img_808 +2002/08/14/big/img_803 +2002/08/25/big/img_155 +2002/07/30/big/img_1107 +2002/08/29/big/img_18882 +2003/01/15/big/img_598 +2002/08/19/big/img_122 +2002/07/30/big/img_428 +2002/07/24/big/img_684 +2002/08/22/big/img_192 +2002/08/22/big/img_543 +2002/08/07/big/img_1318 +2002/08/18/big/img_25 +2002/07/26/big/img_583 +2002/07/20/big/img_464 +2002/08/19/big/img_664 +2002/08/24/big/img_861 +2002/09/01/big/img_16136 +2002/08/22/big/img_400 +2002/08/12/big/img_445 +2003/01/14/big/img_174 +2002/08/27/big/img_19677 +2002/08/31/big/img_17214 +2002/08/30/big/img_18175 +2003/01/17/big/img_402 +2002/08/06/big/img_2396 +2002/08/18/big/img_448 +2002/08/21/big/img_165 +2002/08/31/big/img_17609 +2003/01/01/big/img_151 +2002/08/26/big/img_372 +2002/09/02/big/img_15994 +2002/07/26/big/img_660 +2002/09/02/big/img_15197 +2002/07/29/big/img_258 +2002/08/30/big/img_18525 +2003/01/13/big/img_368 
+2002/07/29/big/img_1538 +2002/07/21/big/img_787 +2002/08/18/big/img_152 +2002/08/06/big/img_2379 +2003/01/17/big/img_864 +2002/08/27/big/img_19998 +2002/08/01/big/img_1634 +2002/07/25/big/img_414 +2002/08/22/big/img_627 +2002/08/07/big/img_1669 +2002/08/16/big/img_1052 +2002/08/31/big/img_17796 +2002/08/18/big/img_199 +2002/09/02/big/img_15147 +2002/08/09/big/img_460 +2002/08/14/big/img_581 +2002/08/30/big/img_18286 +2002/07/26/big/img_337 +2002/08/18/big/img_589 +2003/01/14/big/img_866 +2002/07/20/big/img_624 +2002/08/01/big/img_1801 +2002/07/24/big/img_683 +2002/08/09/big/img_725 +2003/01/14/big/img_34 +2002/07/30/big/img_144 +2002/07/30/big/img_706 +2002/08/08/big/img_394 +2002/08/19/big/img_619 +2002/08/06/big/img_2703 +2002/08/29/big/img_19034 +2002/07/24/big/img_67 +2002/08/27/big/img_19841 +2002/08/19/big/img_427 +2003/01/14/big/img_333 +2002/09/01/big/img_16406 +2002/07/19/big/img_882 +2002/08/17/big/img_238 +2003/01/14/big/img_739 +2002/07/22/big/img_151 +2002/08/21/big/img_743 +2002/07/25/big/img_1048 +2002/07/30/big/img_395 +2003/01/13/big/img_584 +2002/08/13/big/img_742 +2002/08/13/big/img_1168 +2003/01/14/big/img_147 +2002/07/26/big/img_803 +2002/08/05/big/img_3298 +2002/08/07/big/img_1451 +2002/08/16/big/img_424 +2002/07/29/big/img_1069 +2002/09/01/big/img_16735 +2002/07/21/big/img_637 +2003/01/14/big/img_585 +2002/08/02/big/img_358 +2003/01/13/big/img_358 +2002/08/14/big/img_198 +2002/08/17/big/img_935 +2002/08/04/big/img_42 +2002/08/30/big/img_18245 +2002/07/25/big/img_158 +2002/08/22/big/img_744 +2002/08/06/big/img_2291 +2002/08/05/big/img_3044 +2002/07/30/big/img_272 +2002/08/23/big/img_641 +2002/07/24/big/img_797 +2002/07/30/big/img_392 +2003/01/14/big/img_447 +2002/07/31/big/img_898 +2002/08/06/big/img_2812 +2002/08/13/big/img_564 +2002/07/22/big/img_43 +2002/07/26/big/img_634 +2002/07/19/big/img_843 +2002/08/26/big/img_58 +2002/07/21/big/img_375 +2002/08/25/big/img_729 +2002/07/19/big/img_561 +2003/01/15/big/img_884 +2002/07/25/big/img_891 
+2002/08/09/big/img_558 +2002/08/26/big/img_587 +2002/08/13/big/img_1146 +2002/09/02/big/img_15153 +2002/07/26/big/img_316 +2002/08/01/big/img_1940 +2002/08/26/big/img_90 +2003/01/13/big/img_347 +2002/07/25/big/img_520 +2002/08/29/big/img_18718 +2002/08/28/big/img_19219 +2002/08/13/big/img_375 +2002/07/20/big/img_719 +2002/08/31/big/img_17431 +2002/07/28/big/img_192 +2002/08/26/big/img_259 +2002/08/18/big/img_484 +2002/07/29/big/img_580 +2002/07/26/big/img_84 +2002/08/02/big/img_302 +2002/08/31/big/img_17007 +2003/01/15/big/img_543 +2002/09/01/big/img_16488 +2002/08/22/big/img_798 +2002/07/30/big/img_383 +2002/08/04/big/img_668 +2002/08/13/big/img_156 +2002/08/07/big/img_1353 +2002/07/25/big/img_281 +2003/01/14/big/img_587 +2003/01/15/big/img_524 +2002/08/19/big/img_726 +2002/08/21/big/img_709 +2002/08/26/big/img_465 +2002/07/31/big/img_658 +2002/08/28/big/img_19148 +2002/07/23/big/img_423 +2002/08/16/big/img_758 +2002/08/22/big/img_523 +2002/08/16/big/img_591 +2002/08/23/big/img_845 +2002/07/26/big/img_678 +2002/08/09/big/img_806 +2002/08/06/big/img_2369 +2002/07/29/big/img_457 +2002/07/19/big/img_278 +2002/08/30/big/img_18107 +2002/07/26/big/img_444 +2002/08/20/big/img_278 +2002/08/26/big/img_92 +2002/08/26/big/img_257 +2002/07/25/big/img_266 +2002/08/05/big/img_3829 +2002/07/26/big/img_757 +2002/07/29/big/img_1536 +2002/08/09/big/img_472 +2003/01/17/big/img_480 +2002/08/28/big/img_19355 +2002/07/26/big/img_97 +2002/08/06/big/img_2503 +2002/07/19/big/img_254 +2002/08/01/big/img_1470 +2002/08/21/big/img_42 +2002/08/20/big/img_217 +2002/08/06/big/img_2459 +2002/07/19/big/img_552 +2002/08/13/big/img_717 +2002/08/12/big/img_586 +2002/08/20/big/img_411 +2003/01/13/big/img_768 +2002/08/07/big/img_1747 +2002/08/15/big/img_385 +2002/08/01/big/img_1648 +2002/08/15/big/img_311 +2002/08/21/big/img_95 +2002/08/09/big/img_108 +2002/08/21/big/img_398 +2002/08/17/big/img_340 +2002/08/14/big/img_474 +2002/08/13/big/img_294 +2002/08/24/big/img_840 +2002/08/09/big/img_808 
+2002/08/23/big/img_491 +2002/07/28/big/img_33 +2003/01/13/big/img_664 +2002/08/02/big/img_261 +2002/08/09/big/img_591 +2002/07/26/big/img_309 +2003/01/14/big/img_372 +2002/08/19/big/img_581 +2002/08/19/big/img_168 +2002/08/26/big/img_422 +2002/07/24/big/img_106 +2002/08/01/big/img_1936 +2002/08/05/big/img_3764 +2002/08/21/big/img_266 +2002/08/31/big/img_17968 +2002/08/01/big/img_1941 +2002/08/15/big/img_550 +2002/08/14/big/img_13 +2002/07/30/big/img_171 +2003/01/13/big/img_490 +2002/07/25/big/img_427 +2002/07/19/big/img_770 +2002/08/12/big/img_759 +2003/01/15/big/img_1360 +2002/08/05/big/img_3692 +2003/01/16/big/img_30 +2002/07/25/big/img_1026 +2002/07/22/big/img_288 +2002/08/29/big/img_18801 +2002/07/24/big/img_793 +2002/08/13/big/img_178 +2002/08/06/big/img_2322 +2003/01/14/big/img_560 +2002/08/18/big/img_408 +2003/01/16/big/img_915 +2003/01/16/big/img_679 +2002/08/07/big/img_1552 +2002/08/29/big/img_19050 +2002/08/01/big/img_2172 +2002/07/31/big/img_30 +2002/07/30/big/img_1019 +2002/07/30/big/img_587 +2003/01/13/big/img_773 +2002/07/30/big/img_410 +2002/07/28/big/img_65 +2002/08/05/big/img_3138 +2002/07/23/big/img_541 +2002/08/22/big/img_963 +2002/07/27/big/img_657 +2002/07/30/big/img_1051 +2003/01/16/big/img_150 +2002/07/31/big/img_519 +2002/08/01/big/img_1961 +2002/08/05/big/img_3752 +2002/07/23/big/img_631 +2003/01/14/big/img_237 +2002/07/28/big/img_21 +2002/07/22/big/img_813 +2002/08/05/big/img_3563 +2003/01/17/big/img_620 +2002/07/19/big/img_523 +2002/07/30/big/img_904 +2002/08/29/big/img_18642 +2002/08/11/big/img_492 +2002/08/01/big/img_2130 +2002/07/25/big/img_618 +2002/08/17/big/img_305 +2003/01/16/big/img_520 +2002/07/26/big/img_495 +2002/08/17/big/img_164 +2002/08/03/big/img_440 +2002/07/24/big/img_441 +2002/08/06/big/img_2146 +2002/08/11/big/img_558 +2002/08/02/big/img_545 +2002/08/31/big/img_18090 +2003/01/01/big/img_136 +2002/07/25/big/img_1099 +2003/01/13/big/img_728 +2003/01/16/big/img_197 +2002/07/26/big/img_651 +2002/08/11/big/img_676 
+2003/01/15/big/img_10 +2002/08/21/big/img_250 +2002/08/14/big/img_325 +2002/08/04/big/img_390 +2002/07/24/big/img_554 +2003/01/16/big/img_333 +2002/07/31/big/img_922 +2002/09/02/big/img_15586 +2003/01/16/big/img_184 +2002/07/22/big/img_766 +2002/07/21/big/img_608 +2002/08/07/big/img_1578 +2002/08/17/big/img_961 +2002/07/27/big/img_324 +2002/08/05/big/img_3765 +2002/08/23/big/img_462 +2003/01/16/big/img_382 +2002/08/27/big/img_19838 +2002/08/01/big/img_1505 +2002/08/21/big/img_662 +2002/08/14/big/img_605 +2002/08/19/big/img_816 +2002/07/29/big/img_136 +2002/08/20/big/img_719 +2002/08/06/big/img_2826 +2002/08/10/big/img_630 +2003/01/17/big/img_973 +2002/08/14/big/img_116 +2002/08/02/big/img_666 +2002/08/21/big/img_710 +2002/08/05/big/img_55 +2002/07/31/big/img_229 +2002/08/01/big/img_1549 +2002/07/23/big/img_432 +2002/07/21/big/img_430 +2002/08/21/big/img_549 +2002/08/08/big/img_985 +2002/07/20/big/img_610 +2002/07/23/big/img_978 +2002/08/23/big/img_219 +2002/07/25/big/img_175 +2003/01/15/big/img_230 +2002/08/23/big/img_385 +2002/07/31/big/img_879 +2002/08/12/big/img_495 +2002/08/22/big/img_499 +2002/08/30/big/img_18322 +2002/08/15/big/img_795 +2002/08/13/big/img_835 +2003/01/17/big/img_930 +2002/07/30/big/img_873 +2002/08/11/big/img_257 +2002/07/31/big/img_593 +2002/08/21/big/img_916 +2003/01/13/big/img_814 +2002/07/25/big/img_722 +2002/08/16/big/img_379 +2002/07/31/big/img_497 +2002/07/22/big/img_602 +2002/08/21/big/img_642 +2002/08/21/big/img_614 +2002/08/23/big/img_482 +2002/07/29/big/img_603 +2002/08/13/big/img_705 +2002/07/23/big/img_833 +2003/01/14/big/img_511 +2002/07/24/big/img_376 +2002/08/17/big/img_1030 +2002/08/05/big/img_3576 +2002/08/16/big/img_540 +2002/07/22/big/img_630 +2002/08/10/big/img_180 +2002/08/14/big/img_905 +2002/08/29/big/img_18777 +2002/08/22/big/img_693 +2003/01/16/big/img_933 +2002/08/20/big/img_555 +2002/08/15/big/img_549 +2003/01/14/big/img_830 +2003/01/16/big/img_64 +2002/08/27/big/img_19670 +2002/08/22/big/img_729 
+2002/07/27/big/img_981 +2002/08/09/big/img_458 +2003/01/17/big/img_884 +2002/07/25/big/img_639 +2002/08/31/big/img_18008 +2002/08/22/big/img_249 +2002/08/17/big/img_971 +2002/08/04/big/img_308 +2002/07/28/big/img_362 +2002/08/12/big/img_142 +2002/08/26/big/img_61 +2002/08/14/big/img_422 +2002/07/19/big/img_607 +2003/01/15/big/img_717 +2002/08/01/big/img_1475 +2002/08/29/big/img_19061 +2003/01/01/big/img_346 +2002/07/20/big/img_315 +2003/01/15/big/img_756 +2002/08/15/big/img_879 +2002/08/08/big/img_615 +2003/01/13/big/img_431 +2002/08/05/big/img_3233 +2002/08/24/big/img_526 +2003/01/13/big/img_717 +2002/09/01/big/img_16408 +2002/07/22/big/img_217 +2002/07/31/big/img_960 +2002/08/21/big/img_610 +2002/08/05/big/img_3753 +2002/08/03/big/img_151 +2002/08/21/big/img_267 +2002/08/01/big/img_2175 +2002/08/04/big/img_556 +2002/08/21/big/img_527 +2002/09/02/big/img_15800 +2002/07/27/big/img_156 +2002/07/20/big/img_590 +2002/08/15/big/img_700 +2002/08/08/big/img_444 +2002/07/25/big/img_94 +2002/07/24/big/img_778 +2002/08/14/big/img_694 +2002/07/20/big/img_666 +2002/08/02/big/img_200 +2002/08/02/big/img_578 +2003/01/17/big/img_332 +2002/09/01/big/img_16352 +2002/08/27/big/img_19668 +2002/07/23/big/img_823 +2002/08/13/big/img_431 +2003/01/16/big/img_463 +2002/08/27/big/img_19711 +2002/08/23/big/img_154 +2002/07/31/big/img_360 +2002/08/23/big/img_555 +2002/08/10/big/img_561 +2003/01/14/big/img_550 +2002/08/07/big/img_1370 +2002/07/30/big/img_1184 +2002/08/01/big/img_1445 +2002/08/23/big/img_22 +2002/07/30/big/img_606 +2003/01/17/big/img_271 +2002/08/31/big/img_17316 +2002/08/16/big/img_973 +2002/07/26/big/img_77 +2002/07/20/big/img_788 +2002/08/06/big/img_2426 +2002/08/07/big/img_1498 +2002/08/16/big/img_358 +2002/08/06/big/img_2851 +2002/08/12/big/img_359 +2002/08/01/big/img_1521 +2002/08/02/big/img_709 +2002/08/20/big/img_935 +2002/08/12/big/img_188 +2002/08/24/big/img_411 +2002/08/22/big/img_680 +2002/08/06/big/img_2480 +2002/07/20/big/img_627 +2002/07/30/big/img_214 
+2002/07/25/big/img_354 +2002/08/02/big/img_636 +2003/01/15/big/img_661 +2002/08/07/big/img_1327 +2002/08/01/big/img_2108 +2002/08/31/big/img_17919 +2002/08/29/big/img_18768 +2002/08/05/big/img_3840 +2002/07/26/big/img_242 +2003/01/14/big/img_451 +2002/08/20/big/img_923 +2002/08/27/big/img_19908 +2002/08/16/big/img_282 +2002/08/19/big/img_440 +2003/01/01/big/img_230 +2002/08/08/big/img_212 +2002/07/20/big/img_443 +2002/08/25/big/img_635 +2003/01/13/big/img_1169 +2002/07/26/big/img_998 +2002/08/15/big/img_995 +2002/08/06/big/img_3002 +2002/07/29/big/img_460 +2003/01/14/big/img_925 +2002/07/23/big/img_539 +2002/08/16/big/img_694 +2003/01/13/big/img_459 +2002/07/23/big/img_249 +2002/08/20/big/img_539 +2002/08/04/big/img_186 +2002/08/26/big/img_264 +2002/07/22/big/img_704 +2002/08/25/big/img_277 +2002/08/22/big/img_988 +2002/07/29/big/img_504 +2002/08/05/big/img_3600 +2002/08/30/big/img_18380 +2003/01/14/big/img_937 +2002/08/21/big/img_254 +2002/08/10/big/img_130 +2002/08/20/big/img_339 +2003/01/14/big/img_428 +2002/08/20/big/img_889 +2002/08/31/big/img_17637 +2002/07/26/big/img_644 +2002/09/01/big/img_16776 +2002/08/06/big/img_2239 +2002/08/06/big/img_2646 +2003/01/13/big/img_491 +2002/08/10/big/img_579 +2002/08/21/big/img_713 +2002/08/22/big/img_482 +2002/07/22/big/img_167 +2002/07/24/big/img_539 +2002/08/14/big/img_721 +2002/07/25/big/img_389 +2002/09/01/big/img_16591 +2002/08/13/big/img_543 +2003/01/14/big/img_432 +2002/08/09/big/img_287 +2002/07/26/big/img_126 +2002/08/23/big/img_412 +2002/08/15/big/img_1034 +2002/08/28/big/img_19485 +2002/07/31/big/img_236 +2002/07/30/big/img_523 +2002/07/19/big/img_141 +2003/01/17/big/img_957 +2002/08/04/big/img_81 +2002/07/25/big/img_206 +2002/08/15/big/img_716 +2002/08/13/big/img_403 +2002/08/15/big/img_685 +2002/07/26/big/img_884 +2002/07/19/big/img_499 +2002/07/23/big/img_772 +2002/07/27/big/img_752 +2003/01/14/big/img_493 +2002/08/25/big/img_664 +2002/07/31/big/img_334 +2002/08/26/big/img_678 +2002/09/01/big/img_16541 
+2003/01/14/big/img_347 +2002/07/23/big/img_187 +2002/07/30/big/img_1163 +2002/08/05/big/img_35 +2002/08/22/big/img_944 +2002/08/07/big/img_1239 +2002/07/29/big/img_1215 +2002/08/03/big/img_312 +2002/08/05/big/img_3523 +2002/07/29/big/img_218 +2002/08/13/big/img_672 +2002/08/16/big/img_205 +2002/08/17/big/img_594 +2002/07/29/big/img_1411 +2002/07/30/big/img_942 +2003/01/16/big/img_312 +2002/08/08/big/img_312 +2002/07/25/big/img_15 +2002/08/09/big/img_839 +2002/08/01/big/img_2069 +2002/08/31/big/img_17512 +2002/08/01/big/img_3 +2002/07/31/big/img_320 +2003/01/15/big/img_1265 +2002/08/14/big/img_563 +2002/07/31/big/img_167 +2002/08/20/big/img_374 +2002/08/13/big/img_406 +2002/08/08/big/img_625 +2002/08/02/big/img_314 +2002/08/27/big/img_19964 +2002/09/01/big/img_16670 +2002/07/31/big/img_599 +2002/08/29/big/img_18906 +2002/07/24/big/img_373 +2002/07/26/big/img_513 +2002/09/02/big/img_15497 +2002/08/19/big/img_117 +2003/01/01/big/img_158 +2002/08/24/big/img_178 +2003/01/13/big/img_935 +2002/08/13/big/img_609 +2002/08/30/big/img_18341 +2002/08/25/big/img_674 +2003/01/13/big/img_209 +2002/08/13/big/img_258 +2002/08/05/big/img_3543 +2002/08/07/big/img_1970 +2002/08/06/big/img_3004 +2003/01/17/big/img_487 +2002/08/24/big/img_873 +2002/08/29/big/img_18730 +2002/08/09/big/img_375 +2003/01/16/big/img_751 +2002/08/02/big/img_603 +2002/08/19/big/img_325 +2002/09/01/big/img_16420 +2002/08/05/big/img_3633 +2002/08/21/big/img_516 +2002/07/19/big/img_501 +2002/07/26/big/img_688 +2002/07/24/big/img_256 +2002/07/25/big/img_438 +2002/07/31/big/img_1017 +2002/08/22/big/img_512 +2002/07/21/big/img_543 +2002/08/08/big/img_223 +2002/08/19/big/img_189 +2002/08/12/big/img_630 +2002/07/30/big/img_958 +2002/07/28/big/img_208 +2002/08/31/big/img_17691 +2002/07/22/big/img_542 +2002/07/19/big/img_741 +2002/07/19/big/img_158 +2002/08/15/big/img_399 +2002/08/01/big/img_2159 +2002/08/14/big/img_455 +2002/08/17/big/img_1011 +2002/08/26/big/img_744 +2002/08/12/big/img_624 +2003/01/17/big/img_821 
+2002/08/16/big/img_980 +2002/07/28/big/img_281 +2002/07/25/big/img_171 +2002/08/03/big/img_116 +2002/07/22/big/img_467 +2002/07/31/big/img_750 +2002/07/26/big/img_435 +2002/07/19/big/img_822 +2002/08/13/big/img_626 +2002/08/11/big/img_344 +2002/08/02/big/img_473 +2002/09/01/big/img_16817 +2002/08/01/big/img_1275 +2002/08/28/big/img_19270 +2002/07/23/big/img_607 +2002/08/09/big/img_316 +2002/07/29/big/img_626 +2002/07/24/big/img_824 +2002/07/22/big/img_342 +2002/08/08/big/img_794 +2002/08/07/big/img_1209 +2002/07/19/big/img_18 +2002/08/25/big/img_634 +2002/07/24/big/img_730 +2003/01/17/big/img_356 +2002/07/23/big/img_305 +2002/07/30/big/img_453 +2003/01/13/big/img_972 +2002/08/06/big/img_2610 +2002/08/29/big/img_18920 +2002/07/31/big/img_123 +2002/07/26/big/img_979 +2002/08/24/big/img_635 +2002/08/05/big/img_3704 +2002/08/07/big/img_1358 +2002/07/22/big/img_306 +2002/08/13/big/img_619 +2002/08/02/big/img_366 diff --git a/face_vid2vid/GPEN/retinaface/data/__init__.py b/face_vid2vid/GPEN/retinaface/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea50ebaf88d64e75f4960bc99b14f138a343e575 --- /dev/null +++ b/face_vid2vid/GPEN/retinaface/data/__init__.py @@ -0,0 +1,3 @@ +from .wider_face import WiderFaceDetection, detection_collate +from .data_augment import * +from .config import * diff --git a/face_vid2vid/GPEN/retinaface/data/config.py b/face_vid2vid/GPEN/retinaface/data/config.py new file mode 100644 index 0000000000000000000000000000000000000000..e57cdc530e3d78c4aa6310985c90c5ee125f8f01 --- /dev/null +++ b/face_vid2vid/GPEN/retinaface/data/config.py @@ -0,0 +1,42 @@ +# config.py + +cfg_mnet = { + 'name': 'mobilenet0.25', + 'min_sizes': [[16, 32], [64, 128], [256, 512]], + 'steps': [8, 16, 32], + 'variance': [0.1, 0.2], + 'clip': False, + 'loc_weight': 2.0, + 'gpu_train': True, + 'batch_size': 32, + 'ngpu': 1, + 'epoch': 250, + 'decay1': 190, + 'decay2': 220, + 'image_size': 640, + 'pretrain': False, + 'return_layers': {'stage1': 
def _crop(image, boxes, labels, landm, img_dim):
    """Try to cut a random square crop out of ``image`` that still holds a box.

    Up to 250 attempts are made.  Each attempt picks a scale from
    ``PRE_SCALES``, derives a square side from the image's short side, places
    the square at a random offset and keeps the attempt only if at least one
    ground-truth box survives all the filters below.

    Args:
        image: array whose ``shape`` unpacks as ``(height, width, channels)``.
        boxes: 2-D array; columns ``[:2]`` and ``[2:]`` are compared against
            the crop's ``(left, top)`` and ``(right, bottom)`` corners.
        labels: per-box array, filtered with the same masks as ``boxes``.
        landm: per-box landmark array; it is reshaped to ``(-1, 5, 2)``, i.e.
            ten values per box.
        img_dim: training image size, used to rescale box width/height before
            the final size filter.

    Returns:
        ``(image_t, boxes_t, labels_t, landms_t, pad_image_flag)``.  On a
        successful crop ``pad_image_flag`` is ``False`` and all arrays refer to
        the crop (boxes/landmarks translated into crop coordinates).  If no
        attempt succeeds, the unmodified inputs are returned with
        ``pad_image_flag`` set to ``True``.
    """
    height, width, _ = image.shape
    pad_image_flag = True

    for _ in range(250):
        """
        if random.uniform(0, 1) <= 0.2:
            scale = 1.0
        else:
            scale = random.uniform(0.3, 1.0)
        """
        # NOTE: the bare string above is disabled legacy sampling code; it is
        # an expression statement with no effect.
        PRE_SCALES = [0.3, 0.45, 0.6, 0.8, 1.0]
        scale = random.choice(PRE_SCALES)
        short_side = min(width, height)
        # Square crop: both sides are derived from the image's short side.
        w = int(scale * short_side)
        h = w

        # randrange(0) would raise, so a crop spanning the full width/height
        # is pinned to offset 0.
        if width == w:
            l = 0
        else:
            l = random.randrange(width - w)
        if height == h:
            t = 0
        else:
            t = random.randrange(height - h)
        # Crop rectangle as (left, top, right, bottom).
        roi = np.array((l, t, l + w, t + h))

        # NOTE(review): matrix_iof comes from utils.box_utils and is not
        # visible here; presumably a value >= 1 means a box lies entirely
        # inside the crop -- confirm against that helper.
        value = matrix_iof(boxes, roi[np.newaxis])
        flag = (value >= 1)
        if not flag.any():
            continue

        # Keep only boxes whose centre lies strictly inside the crop.
        centers = (boxes[:, :2] + boxes[:, 2:]) / 2
        mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1)
        boxes_t = boxes[mask_a].copy()
        labels_t = labels[mask_a].copy()
        landms_t = landm[mask_a].copy()
        landms_t = landms_t.reshape([-1, 5, 2])

        if boxes_t.shape[0] == 0:
            continue

        image_t = image[roi[1]:roi[3], roi[0]:roi[2]]

        # Clip the kept boxes to the crop, then shift them so the crop's
        # top-left corner becomes the origin.
        boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2])
        boxes_t[:, :2] -= roi[:2]
        boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:])
        boxes_t[:, 2:] -= roi[:2]

        # Landmarks: shift into crop coordinates, then clamp each point to
        # the range [0, crop size].
        landms_t[:, :, :2] = landms_t[:, :, :2] - roi[:2]
        landms_t[:, :, :2] = np.maximum(landms_t[:, :, :2], np.array([0, 0]))
        landms_t[:, :, :2] = np.minimum(landms_t[:, :, :2], roi[2:] - roi[:2])
        landms_t = landms_t.reshape([-1, 10])


        # make sure that the cropped image contains at least one face > 16 pixel at training image scale
        # NOTE(review): the threshold actually applied below is 0.0, not 16,
        # so this filter only drops boxes with a non-positive rescaled side --
        # confirm which of the two is intended.
        b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim
        b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim
        mask_b = np.minimum(b_w_t, b_h_t) > 0.0
        boxes_t = boxes_t[mask_b]
        labels_t = labels_t[mask_b]
        landms_t = landms_t[mask_b]

        if boxes_t.shape[0] == 0:
            continue

        pad_image_flag = False

        return image_t, boxes_t, labels_t, landms_t, pad_image_flag
    # Every attempt failed: hand back the inputs untouched, flag still True.
    return image, boxes, labels, landm, pad_image_flag
:, 0].astype(int) + random.randint(-18, 18) + tmp %= 180 + image[:, :, 0] = tmp + + image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) + + #contrast distortion + if random.randrange(2): + _convert(image, alpha=random.uniform(0.5, 1.5)) + + return image + + +def _expand(image, boxes, fill, p): + if random.randrange(2): + return image, boxes + + height, width, depth = image.shape + + scale = random.uniform(1, p) + w = int(scale * width) + h = int(scale * height) + + left = random.randint(0, w - width) + top = random.randint(0, h - height) + + boxes_t = boxes.copy() + boxes_t[:, :2] += (left, top) + boxes_t[:, 2:] += (left, top) + expand_image = np.empty( + (h, w, depth), + dtype=image.dtype) + expand_image[:, :] = fill + expand_image[top:top + height, left:left + width] = image + image = expand_image + + return image, boxes_t + + +def _mirror(image, boxes, landms): + _, width, _ = image.shape + if random.randrange(2): + image = image[:, ::-1] + boxes = boxes.copy() + boxes[:, 0::2] = width - boxes[:, 2::-2] + + # landm + landms = landms.copy() + landms = landms.reshape([-1, 5, 2]) + landms[:, :, 0] = width - landms[:, :, 0] + tmp = landms[:, 1, :].copy() + landms[:, 1, :] = landms[:, 0, :] + landms[:, 0, :] = tmp + tmp1 = landms[:, 4, :].copy() + landms[:, 4, :] = landms[:, 3, :] + landms[:, 3, :] = tmp1 + landms = landms.reshape([-1, 10]) + + return image, boxes, landms + + +def _pad_to_square(image, rgb_mean, pad_image_flag): + if not pad_image_flag: + return image + height, width, _ = image.shape + long_side = max(width, height) + image_t = np.empty((long_side, long_side, 3), dtype=image.dtype) + image_t[:, :] = rgb_mean + image_t[0:0 + height, 0:0 + width] = image + return image_t + + +def _resize_subtract_mean(image, insize, rgb_mean): + interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4] + interp_method = interp_methods[random.randrange(5)] + image = cv2.resize(image, (insize, insize), 
class WiderFaceDetection(data.Dataset):
    """Face-detection dataset driven by a single text label file.

    The label file is a sequence of records.  A line starting with ``#``
    opens a record; the text after the first two characters is the image
    path relative to the image root.  Every following non-header line holds
    space-separated numbers describing one face in that image.

    The image root is obtained by replacing ``'label.txt'`` in ``txt_path``
    with ``'images/'``.

    Attributes:
        preproc: optional callable ``(img, target) -> (img, target)`` applied
            in ``__getitem__``.
        imgs_path: list of image paths, one per header line.
        words: list (parallel to ``imgs_path``) of per-image lists of parsed
            float rows.
    """

    # Positions inside one parsed label row that are copied into annotation
    # columns 4..13 (five x/y pairs); positions 6, 9, 12 and 15 are skipped.
    _LANDMARK_FIELDS = (4, 5, 7, 8, 10, 11, 13, 14, 16, 17)

    def __init__(self, txt_path, preproc=None):
        """Parse ``txt_path`` into ``imgs_path`` and ``words``.

        Args:
            txt_path: path of the label file.
            preproc: optional preprocessing callable, stored as-is.

        Raises:
            OSError: if the label file cannot be opened.
            ValueError: if a label line contains a non-numeric field.
        """
        self.preproc = preproc
        self.imgs_path = []
        self.words = []

        image_root = txt_path.replace('label.txt','images/')
        labels = []
        seen_header = False
        # The context manager guarantees the handle is closed even when a
        # malformed line raises part-way through parsing.
        with open(txt_path, 'r') as f:
            for line in f:
                line = line.rstrip()
                if not line:
                    # Blank lines (e.g. a trailing newline) carry no data.
                    continue
                if line.startswith('#'):
                    if seen_header:
                        # Close the previous record before starting a new one.
                        self.words.append(labels)
                        labels = []
                    seen_header = True
                    self.imgs_path.append(image_root + line[2:])
                else:
                    labels.append([float(x) for x in line.split(' ')])

        # Flush the final record (also appended when the file has no header).
        self.words.append(labels)

    def __len__(self):
        """Return the number of images listed in the label file."""
        return len(self.imgs_path)

    @staticmethod
    def _to_annotations(labels):
        """Convert parsed label rows into an ``(n, 15)`` float array.

        Columns 0-3 are ``x1, y1, x1 + w, y1 + h``; columns 4-13 are the ten
        landmark values; column 14 is ``-1`` when the first landmark value is
        negative and ``1`` otherwise.
        """
        annotations = np.zeros((len(labels), 15))
        for row, label in zip(annotations, labels):
            row[0] = label[0]  # x1
            row[1] = label[1]  # y1
            row[2] = label[0] + label[2]  # x2
            row[3] = label[1] + label[3]  # y2
            row[4:14] = [label[i] for i in WiderFaceDetection._LANDMARK_FIELDS]
            row[14] = -1 if row[4] < 0 else 1
        return annotations

    def __getitem__(self, index):
        """Load image ``index`` and its annotation array.

        Returns:
            ``(torch tensor of the image, target array)`` in the normal case.
            NOTE(review): when the image has no label rows, only the empty
            ``(0, 15)`` array is returned (no image, no tuple); this is kept
            for backward compatibility -- confirm callers rely on it.

        Raises:
            FileNotFoundError: if ``cv2.imread`` cannot read the image.
        """
        path = self.imgs_path[index]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('cannot read image: {}'.format(path))

        labels = self.words[index]
        if len(labels) == 0:
            return np.zeros((0, 15))

        target = self._to_annotations(labels)
        if self.preproc is not None:
            img, target = self.preproc(img, target)

        return torch.from_numpy(img), target
def conv_bn(inp, oup, stride=1, leaky=0):
    """Build a 3x3 convolution -> batch norm -> LeakyReLU block.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the convolution.
        leaky: negative slope of the LeakyReLU.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.  The
        convolution uses padding 1 and no bias; the activation is in-place.
    """
    convolution = nn.Conv2d(
        in_channels=inp,
        out_channels=oup,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    normalisation = nn.BatchNorm2d(num_features=oup)
    activation = nn.LeakyReLU(negative_slope=leaky, inplace=True)
    return nn.Sequential(convolution, normalisation, activation)
def conv_bn(inp, oup, stride=1, leaky=0):
    """3x3 convolution (padding 1, no bias) -> BatchNorm -> LeakyReLU."""
    layers = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*layers)


def conv_bn_no_relu(inp, oup, stride):
    """3x3 convolution (padding 1, no bias) -> BatchNorm, without activation."""
    layers = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
    ]
    return nn.Sequential(*layers)


def conv_bn1X1(inp, oup, stride, leaky=0):
    """1x1 convolution (no padding, no bias) -> BatchNorm -> LeakyReLU."""
    layers = [
        nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*layers)


def conv_dw(inp, oup, stride, leaky=0.1):
    """Depthwise 3x3 conv followed by pointwise 1x1 conv, each with BN + LeakyReLU."""
    depthwise = [
        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    pointwise = [
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*depthwise, *pointwise)


class SSH(nn.Module):
    """Context module that concatenates a 3x3 branch with two stacked-3x3 branches.

    The branches contribute out_channel//2, out_channel//4 and out_channel//4
    channels, so the output has exactly ``out_channel`` channels.
    """

    def __init__(self, in_channel, out_channel):
        super(SSH, self).__init__()
        assert out_channel % 4 == 0
        # narrow heads use a leaky slope, wide ones a plain ReLU (slope 0)
        leaky = 0.1 if out_channel <= 64 else 0
        half, quarter = out_channel // 2, out_channel // 4

        # NOTE: attribute names (including the mixed-case "conv7x7_3") are
        # state_dict keys and must not be renamed.
        self.conv3X3 = conv_bn_no_relu(in_channel, half, stride=1)

        self.conv5X5_1 = conv_bn(in_channel, quarter, stride=1, leaky=leaky)
        self.conv5X5_2 = conv_bn_no_relu(quarter, quarter, stride=1)

        self.conv7X7_2 = conv_bn(quarter, quarter, stride=1, leaky=leaky)
        self.conv7x7_3 = conv_bn_no_relu(quarter, quarter, stride=1)

    def forward(self, input):
        branch3 = self.conv3X3(input)

        shared = self.conv5X5_1(input)
        branch5 = self.conv5X5_2(shared)

        # the third branch stacks two more 3x3 convs on the shared stem
        branch7 = self.conv7x7_3(self.conv7X7_2(shared))

        return F.relu(torch.cat([branch3, branch5, branch7], dim=1))


class FPN(nn.Module):
    """Three-level feature pyramid with top-down nearest-neighbour merging."""

    def __init__(self, in_channels_list, out_channels):
        super(FPN, self).__init__()
        leaky = 0.1 if out_channels <= 64 else 0
        self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
        self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
        self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)

        self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
        self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)

    def forward(self, input):
        # `input` is a mapping of feature maps; only the value order matters
        c1, c2, c3 = list(input.values())[:3]

        p1 = self.output1(c1)
        p2 = self.output2(c2)
        p3 = self.output3(c3)

        # top-down pathway: upsample the coarser map to the finer one's size
        p2 = self.merge2(p2 + F.interpolate(p3, size=[p2.size(2), p2.size(3)], mode="nearest"))
        p1 = self.merge1(p1 + F.interpolate(p2, size=[p1.size(2), p1.size(3)], mode="nearest"))

        return [p1, p2, p3]


class MobileNetV1(nn.Module):
    """Slim MobileNetV1 backbone (8..256 channels) with a 1000-way classifier."""

    def __init__(self):
        super(MobileNetV1, self).__init__()
        # (in, out, stride) of the depthwise-separable blocks of each stage
        stage1_dw = ((8, 16, 1), (16, 32, 2), (32, 32, 1), (32, 64, 2), (64, 64, 1))
        stage2_dw = ((64, 128, 2),) + ((128, 128, 1),) * 5
        stage3_dw = ((128, 256, 2), (256, 256, 1))

        self.stage1 = nn.Sequential(
            conv_bn(3, 8, 2, leaky=0.1),
            *[conv_dw(i, o, s) for i, o, s in stage1_dw]
        )
        self.stage2 = nn.Sequential(*[conv_dw(i, o, s) for i, o, s in stage2_dw])
        self.stage3 = nn.Sequential(*[conv_dw(i, o, s) for i, o, s in stage3_dw])
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, 1000)

    def forward(self, x):
        for stage in (self.stage1, self.stage2, self.stage3):
            x = stage(x)
        pooled = self.avg(x)
        # flatten the 256-channel 1x1 map before the classifier
        return self.fc(pooled.view(-1, 256))
class ClassHead(nn.Module):
    """1x1 conv head producing 2 class scores per anchor.

    forward() returns a tensor of shape (batch, H*W*num_anchors, 2).
    """

    def __init__(self, inchannels=512, num_anchors=3):
        super(ClassHead, self).__init__()
        self.num_anchors = num_anchors
        self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors*2, kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        # channels-last so that the per-anchor values are contiguous
        out = out.permute(0, 2, 3, 1).contiguous()

        return out.view(out.shape[0], -1, 2)


class BboxHead(nn.Module):
    """1x1 conv head producing 4 box offsets per anchor.

    forward() returns a tensor of shape (batch, H*W*num_anchors, 4).
    """

    def __init__(self, inchannels=512, num_anchors=3):
        super(BboxHead, self).__init__()
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors*4, kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        out = out.permute(0, 2, 3, 1).contiguous()

        return out.view(out.shape[0], -1, 4)


class LandmarkHead(nn.Module):
    """1x1 conv head producing 10 landmark offsets (5 points) per anchor.

    forward() returns a tensor of shape (batch, H*W*num_anchors, 10).
    """

    def __init__(self, inchannels=512, num_anchors=3):
        super(LandmarkHead, self).__init__()
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors*10, kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        out = self.conv1x1(x)
        out = out.permute(0, 2, 3, 1).contiguous()

        return out.view(out.shape[0], -1, 10)


class RetinaFace(nn.Module):
    """Backbone + FPN + SSH + per-level class / bbox / landmark heads."""

    def __init__(self, cfg=None, phase='train'):
        """
        :param cfg: Network related settings (dict). Keys read here: 'name',
            'pretrain', 'return_layers', 'in_channel', 'out_channel'.
        :param phase: train or test. In any phase other than 'train' the class
            scores are passed through a softmax in forward().
        :raises ValueError: if cfg is missing or names an unsupported backbone.
        """
        super(RetinaFace, self).__init__()
        if cfg is None:
            # the default exists only for signature compatibility; fail clearly
            raise ValueError("RetinaFace requires a cfg dict (got None)")
        self.phase = phase

        if cfg['name'] == 'mobilenet0.25':
            backbone = MobileNetV1()
            if cfg['pretrain']:
                checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar", map_location=torch.device('cpu'))
                from collections import OrderedDict
                prefix = 'module.'
                new_state_dict = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    # strip the DataParallel prefix only when it is really there
                    name = k[len(prefix):] if k.startswith(prefix) else k
                    new_state_dict[name] = v
                # load params
                backbone.load_state_dict(new_state_dict)
        elif cfg['name'] == 'Resnet50':
            import torchvision.models as models
            backbone = models.resnet50(pretrained=cfg['pretrain'])
        else:
            raise ValueError("unsupported backbone: %r" % (cfg['name'],))

        self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers'])
        in_channels_stage2 = cfg['in_channel']
        in_channels_list = [
            in_channels_stage2 * 2,
            in_channels_stage2 * 4,
            in_channels_stage2 * 8,
        ]
        out_channels = cfg['out_channel']
        self.fpn = FPN(in_channels_list, out_channels)
        self.ssh1 = SSH(out_channels, out_channels)
        self.ssh2 = SSH(out_channels, out_channels)
        self.ssh3 = SSH(out_channels, out_channels)

        self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
        self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
        self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])

    def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        """Build one ClassHead per pyramid level."""
        return nn.ModuleList([ClassHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        """Build one BboxHead per pyramid level."""
        return nn.ModuleList([BboxHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2):
        """Build one LandmarkHead per pyramid level."""
        return nn.ModuleList([LandmarkHead(inchannels, anchor_num) for _ in range(fpn_num)])

    def forward(self, inputs):
        """Return (bbox_regressions, classifications, ldm_regressions).

        Each output concatenates the three pyramid levels along dim 1.
        """
        out = self.body(inputs)

        # FPN
        fpn = self.fpn(out)

        # SSH
        feature1 = self.ssh1(fpn[0])
        feature2 = self.ssh2(fpn[1])
        feature3 = self.ssh3(fpn[2])
        features = [feature1, feature2, feature3]

        bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
        classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1)
        ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1)

        if self.phase == 'train':
            output = (bbox_regressions, classifications, ldm_regressions)
        else:
            output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
        return output
class PriorBox(object):
    """Generates anchor boxes in normalised (cx, cy, w, h) form.

    One anchor per entry of cfg['min_sizes'][k] is produced for every cell of
    feature map k, whose grid size is ceil(image_size / cfg['steps'][k]).
    """

    def __init__(self, cfg, image_size=None, phase='train'):
        super(PriorBox, self).__init__()
        self.min_sizes = cfg['min_sizes']
        self.steps = cfg['steps']
        self.clip = cfg['clip']
        self.image_size = image_size
        # image_size is (height, width); `phase` is accepted but not used here
        img_h, img_w = self.image_size[0], self.image_size[1]
        self.feature_maps = [[ceil(img_h/step), ceil(img_w/step)] for step in self.steps]
        self.name = "s"

    def forward(self):
        """Return a (num_anchors, 4) tensor; clamped to [0, 1] when cfg['clip']."""
        img_h, img_w = self.image_size[0], self.image_size[1]
        flat = []
        for level, (rows, cols) in enumerate(self.feature_maps):
            step = self.steps[level]
            for row, col in product(range(rows), range(cols)):
                # cell centre, normalised by the image extent
                cx = (col + 0.5) * step / img_w
                cy = (row + 0.5) * step / img_h
                for size in self.min_sizes[level]:
                    flat.extend([cx, cy, size / img_w, size / img_h])

        # back to torch land
        priors = torch.Tensor(flat).view(-1, 4)
        if self.clip:
            priors.clamp_(max=1, min=0)
        return priors
class MultiBoxLoss(nn.Module):
    """SSD Weighted Loss Function
    Compute Targets:
        1) Produce Confidence Target Indices by matching ground truth boxes
           with (default) 'priorboxes' that have jaccard index > threshold parameter
           (default threshold: 0.5).
        2) Produce localization target by 'encoding' variance into offsets of ground
           truth boxes and their matched 'priorboxes'.
        3) Hard negative mining to filter the excessive number of negative examples
           that comes with using a large number of default bounding boxes.
           (default negative:positive ratio 3:1)
    Objective Loss:
        L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
        weighted by α which is set to 1 by cross val.
        Args:
            c: class confidences,
            l: predicted boxes,
            g: ground truth boxes
            N: number of matched default boxes
        See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
        super(MultiBoxLoss, self).__init__()
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        self.variance = [0.1, 0.2]

    def forward(self, predictions, priors, targets):
        """Multibox Loss
        Args:
            predictions (tuple): (loc preds, conf preds, landmark preds).
                loc shape: torch.size(batch_size,num_priors,4)
                conf shape: torch.size(batch_size,num_priors,num_classes)
                landm shape: torch.size(batch_size,num_priors,10)
            priors (tensor): prior boxes, shape: torch.size(num_priors,4)
            targets (sequence of tensors): per-image ground truth; columns
                0:4 are the box, 4:14 the landmarks, the last one the label.

        Returns:
            (loss_l, loss_c, loss_landm): box and class losses normalised by
            the number of positive priors, landmark loss normalised by the
            number of priors whose matched label is > 0.
        """

        loc_data, conf_data, landm_data = predictions
        num = loc_data.size(0)
        num_priors = (priors.size(0))

        # match priors (default boxes) and ground truth boxes
        loc_t = torch.Tensor(num, num_priors, 4)
        landm_t = torch.Tensor(num, num_priors, 10)
        conf_t = torch.LongTensor(num, num_priors)
        for idx in range(num):
            truths = targets[idx][:, :4].data
            labels = targets[idx][:, -1].data
            landms = targets[idx][:, 4:14].data
            defaults = priors.data
            match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx)
        if GPU:
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()
            landm_t = landm_t.cuda()

        # Keep the comparison scalar on the same device as the targets; an
        # unconditional .cuda() here crashed CPU-only runs (GPU == False).
        zeros = torch.tensor(0, device=conf_t.device)
        # landm Loss (Smooth L1)
        # Shape: [batch,num_priors,10]
        pos1 = conf_t > zeros
        num_pos_landm = pos1.long().sum(1, keepdim=True)
        N1 = max(num_pos_landm.data.sum().float(), 1)
        pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
        landm_p = landm_data[pos_idx1].view(-1, 10)
        landm_t = landm_t[pos_idx1].view(-1, 10)
        loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')

        # every non-zero label (including negative ones) counts as a positive
        # for the box and class losses and is collapsed to class 1
        pos = conf_t != zeros
        conf_t[pos] = 1

        # Localization Loss (Smooth L1)
        # Shape: [batch,num_priors,4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')

        # Compute max conf across batch for hard negative mining
        batch_conf = conf_data.view(-1, self.num_classes)
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))

        # Hard Negative Mining
        loss_c[pos.view(-1, 1)] = 0  # filter out pos boxes for now
        loss_c = loss_c.view(num, -1)
        # rank of every prior when sorted by descending loss
        _, loss_idx = loss_c.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)
        num_pos = pos.long().sum(1, keepdim=True)
        num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
        neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence Loss Including Positive and Negative Examples
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)
        targets_weighted = conf_t[(pos+neg).gt(0)]
        loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')

        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        N = max(num_pos.data.sum().float(), 1)
        loss_l /= N
        loss_c /= N
        loss_landm /= N1

        return loss_l, loss_c, loss_landm
numpy as np +from data import cfg_re50 +from layers.functions.prior_box import PriorBox +from utils.nms.py_cpu_nms import py_cpu_nms +import cv2 +from facemodels.retinaface import RetinaFace +from utils.box_utils import decode, decode_landm +import time +import torch + +class RetinaFaceDetection(object): + def __init__(self, base_dir, network='RetinaFace-R50'): + torch.set_grad_enabled(False) + cudnn.benchmark = True + self.pretrained_path = os.path.join(base_dir, 'weights', network+'.pth') + self.device = torch.cuda.current_device() + self.cfg = cfg_re50 + self.net = RetinaFace(cfg=self.cfg, phase='test') + self.load_model() + self.net = self.net.cuda() + self.net_trt = None + + def check_keys(self, pretrained_state_dict): + ckpt_keys = set(pretrained_state_dict.keys()) + model_keys = set(self.net.state_dict().keys()) + used_pretrained_keys = model_keys & ckpt_keys + unused_pretrained_keys = ckpt_keys - model_keys + missing_keys = model_keys - ckpt_keys + assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' + return True + + def remove_prefix(self, state_dict, prefix): + ''' Old style model is stored with all names of parameters sharing common prefix 'module.' 
''' + f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x + return {f(key): value for key, value in state_dict.items()} + + def load_model(self, load_to_cpu=False): + if load_to_cpu: + pretrained_dict = torch.load(self.pretrained_path, map_location=lambda storage, loc: storage) + else: + pretrained_dict = torch.load(self.pretrained_path, map_location=lambda storage, loc: storage.cuda()) + if "state_dict" in pretrained_dict.keys(): + pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'], 'module.') + else: + pretrained_dict = self.remove_prefix(pretrained_dict, 'module.') + self.check_keys(pretrained_dict) + self.net.load_state_dict(pretrained_dict, strict=False) + self.net.eval() + + def build_trt(self, img_raw): + img = np.float32(img_raw) + + img -= (104, 117, 123) + img = img.transpose(2, 0, 1) + img = torch.from_numpy(img).unsqueeze(0) + img = img.cuda() + + print('building trt model FaceGAN') + from torch2trt import torch2trt + self.net_trt = torch2trt(self.net, [img], fp16_mode=True) + del self.net + print('sucessfully built') + + def detect_trt(self, img_raw, resize=1, confidence_threshold=0.9, nms_threshold=0.4, top_k=5000, keep_top_k=750, save_image=False): + img = np.float32(img_raw) + + im_height, im_width = img.shape[:2] + scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) + img -= (104, 117, 123) + img = img.transpose(2, 0, 1) + img = torch.from_numpy(img).unsqueeze(0) + img = img.cuda() + scale = scale.cuda() + + loc, conf, landms = self.net_trt(img) # forward pass + + priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) + priors = priorbox.forward() + priors = priors.cuda() + prior_data = priors.data + boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) + boxes = boxes * scale / resize + boxes = boxes.cpu().numpy() + scores = conf.squeeze(0).data.cpu().numpy()[:, 1] + landms = decode_landm(landms.data.squeeze(0), prior_data, self.cfg['variance']) + scale1 = 
torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2], + img.shape[3], img.shape[2], img.shape[3], img.shape[2], + img.shape[3], img.shape[2]]) + scale1 = scale1.cuda() + landms = landms * scale1 / resize + landms = landms.cpu().numpy() + + # ignore low scores + inds = np.where(scores > confidence_threshold)[0] + boxes = boxes[inds] + landms = landms[inds] + scores = scores[inds] + + # keep top-K before NMS + order = scores.argsort()[::-1][:top_k] + boxes = boxes[order] + landms = landms[order] + scores = scores[order] + + # do NMS + dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) + keep = py_cpu_nms(dets, nms_threshold) + # keep = nms(dets, nms_threshold,force_cpu=args.cpu) + dets = dets[keep, :] + landms = landms[keep] + + # keep top-K faster NMS + dets = dets[:keep_top_k, :] + landms = landms[:keep_top_k, :] + + # sort faces(delete) + ''' + fscores = [det[4] for det in dets] + sorted_idx = sorted(range(len(fscores)), key=lambda k:fscores[k], reverse=False) # sort index + tmp = [landms[idx] for idx in sorted_idx] + landms = np.asarray(tmp) + ''' + + landms = landms.reshape((-1, 5, 2)) + landms = landms.transpose((0, 2, 1)) + landms = landms.reshape(-1, 10, ) + return dets, landms + + + def detect(self, img_raw, resize=1, confidence_threshold=0.9, nms_threshold=0.4, top_k=5000, keep_top_k=750, save_image=False): + img = np.float32(img_raw) + + im_height, im_width = img.shape[:2] + scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) + img -= (104, 117, 123) + img = img.transpose(2, 0, 1) + img = torch.from_numpy(img).unsqueeze(0) + img = img.cuda() + scale = scale.cuda() + + loc, conf, landms = self.net(img) # forward pass + + priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) + priors = priorbox.forward() + priors = priors.cuda() + prior_data = priors.data + boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) + boxes = boxes * scale / resize + boxes = 
boxes.cpu().numpy() + scores = conf.squeeze(0).data.cpu().numpy()[:, 1] + landms = decode_landm(landms.data.squeeze(0), prior_data, self.cfg['variance']) + scale1 = torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2], + img.shape[3], img.shape[2], img.shape[3], img.shape[2], + img.shape[3], img.shape[2]]) + scale1 = scale1.cuda() + landms = landms * scale1 / resize + landms = landms.cpu().numpy() + + # ignore low scores + inds = np.where(scores > confidence_threshold)[0] + boxes = boxes[inds] + landms = landms[inds] + scores = scores[inds] + + # keep top-K before NMS + order = scores.argsort()[::-1][:top_k] + boxes = boxes[order] + landms = landms[order] + scores = scores[order] + + # do NMS + dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) + keep = py_cpu_nms(dets, nms_threshold) + # keep = nms(dets, nms_threshold,force_cpu=args.cpu) + dets = dets[keep, :] + landms = landms[keep] + + # keep top-K faster NMS + dets = dets[:keep_top_k, :] + landms = landms[:keep_top_k, :] + + # sort faces(delete) + ''' + fscores = [det[4] for det in dets] + sorted_idx = sorted(range(len(fscores)), key=lambda k:fscores[k], reverse=False) # sort index + tmp = [landms[idx] for idx in sorted_idx] + landms = np.asarray(tmp) + ''' + + landms = landms.reshape((-1, 5, 2)) + landms = landms.transpose((0, 2, 1)) + landms = landms.reshape(-1, 10, ) + return dets, landms diff --git a/face_vid2vid/GPEN/retinaface/utils/__init__.py b/face_vid2vid/GPEN/retinaface/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/face_vid2vid/GPEN/retinaface/utils/box_utils.py b/face_vid2vid/GPEN/retinaface/utils/box_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d12bc612ae3ba3ea9d138bfc5997a2b15d8dd9 --- /dev/null +++ b/face_vid2vid/GPEN/retinaface/utils/box_utils.py @@ -0,0 +1,330 @@ +import torch +import numpy as np + + +def 
def point_form(boxes):
    """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
    representation for comparison to point form ground truth data.
    Args:
        boxes: (tensor) center-size default boxes from priorbox layers.
    Return:
        boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
    """
    return torch.cat((boxes[:, :2] - boxes[:, 2:]/2,     # xmin, ymin
                     boxes[:, :2] + boxes[:, 2:]/2), 1)  # xmax, ymax


def center_size(boxes):
    """ Convert point-form boxes to (cx, cy, w, h)
    representation for comparison to center-size form ground truth data.
    Args:
        boxes: (tensor) point_form boxes
    Return:
        boxes: (tensor) Converted cx, cy, w, h form of boxes.
    """
    # Both halves must be inside ONE tuple handed to torch.cat; the previous
    # parenthesisation passed them as separate positional arguments.
    return torch.cat(((boxes[:, 2:] + boxes[:, :2])/2,  # cx, cy
                      boxes[:, 2:] - boxes[:, :2]), 1)  # w, h


def intersect(box_a, box_b):
    """ We resize both tensors to [A,B,2] without new malloc:
    [A,2] -> [A,1,2] -> [A,B,2]
    [B,2] -> [1,B,2] -> [A,B,2]
    Then we compute the area of intersect between box_a and box_b.
    Args:
      box_a: (tensor) bounding boxes, Shape: [A,4].
      box_b: (tensor) bounding boxes, Shape: [B,4].
    Return:
      (tensor) intersection area, Shape: [A,B].
    """
    A = box_a.size(0)
    B = box_b.size(0)
    max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
                       box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
    min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
                       box_b[:, :2].unsqueeze(0).expand(A, B, 2))
    inter = torch.clamp((max_xy - min_xy), min=0)
    return inter[:, :, 0] * inter[:, :, 1]


def jaccard(box_a, box_b):
    """Compute the jaccard overlap of two sets of boxes.  The jaccard overlap
    is simply the intersection over union of two boxes.  Here we operate on
    ground truth boxes and default boxes.
    E.g.:
        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
    Args:
        box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
        box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
    Return:
        jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
    """
    inter = intersect(box_a, box_b)
    area_a = ((box_a[:, 2]-box_a[:, 0]) *
              (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
    area_b = ((box_b[:, 2]-box_b[:, 0]) *
              (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
    union = area_a + area_b - inter
    return inter / union  # [A,B]


def matrix_iou(a, b):
    """
    return iou of a and b, numpy version for data augmentation
    """
    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

    # the boolean factor zeroes out pairs that do not overlap at all
    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
    return area_i / (area_a[:, np.newaxis] + area_b - area_i)


def matrix_iof(a, b):
    """
    return iof of a and b (intersection over the area of a, floored at 1),
    numpy version for data augmentation
    """
    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
    return area_i / np.maximum(area_a[:, np.newaxis], 1)


def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
    """Match each prior box with the ground truth box of the highest jaccard
    overlap, encode the bounding boxes, then write the results for batch
    element `idx` into loc_t, conf_t and landm_t in place.
    Args:
        threshold: (float) The overlap threshold used when matching boxes.
        truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
        priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
        variances: (list[float]) Variances used by the encoders.
        labels: (tensor) All the class labels for the image, Shape: [num_obj].
        landms: (tensor) Ground truth landms, Shape [num_obj, 10].
        loc_t: (tensor) Tensor to be filled w/ encoded location targets.
        conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
        landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
        idx: (int) current batch index
    Return:
        None; the three target tensors are updated in place.
    """
    # jaccard index
    overlaps = jaccard(
        truths,
        point_form(priors)
    )
    # (Bipartite Matching)
    # [num_objects,1] best prior for each ground truth
    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)

    # ignore hard gt
    valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
    best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
    if best_prior_idx_filter.shape[0] <= 0:
        loc_t[idx] = 0
        conf_t[idx] = 0
        # callers allocate the target buffers uninitialised, so the landmark
        # slot must be cleared too
        landm_t[idx] = 0
        return

    # [1,num_priors] best ground truth for each prior
    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
    best_truth_idx.squeeze_(0)
    best_truth_overlap.squeeze_(0)
    best_prior_idx.squeeze_(1)
    best_prior_idx_filter.squeeze_(1)
    best_prior_overlap.squeeze_(1)
    best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2)  # ensure best prior
    # TODO refactor: index best_prior_idx with long tensor
    # ensure every gt matches with its prior of max overlap
    for j in range(best_prior_idx.size(0)):  # decide which box this anchor predicts
        best_truth_idx[best_prior_idx[j]] = j
    matches = truths[best_truth_idx]          # Shape: [num_priors,4] matched bbox of every anchor
    conf = labels[best_truth_idx]             # Shape: [num_priors] matched label of every anchor
    conf[best_truth_overlap < threshold] = 0  # overlap below the threshold -> background (negative)
    loc = encode(matches, priors, variances)

    matches_landm = landms[best_truth_idx]
    landm = encode_landm(matches_landm, priors, variances)
    loc_t[idx] = loc    # [num_priors,4] encoded offsets to learn
    conf_t[idx] = conf  # [num_priors] top class label for each prior
    landm_t[idx] = landm


def encode(matched, priors, variances):
    """Encode the variances from the priorbox layers into the ground truth boxes
    we have matched (based on jaccard overlap) with the prior boxes.
    Args:
        matched: (tensor) Coords of ground truth for each prior in point-form
            Shape: [num_priors, 4].
        priors: (tensor) Prior boxes in center-offset form
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        encoded boxes (tensor), Shape: [num_priors, 4]
    """

    # dist b/t match center and prior's center
    g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
    # encode variance
    g_cxcy /= (variances[0] * priors[:, 2:])
    # match wh / prior wh
    g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
    g_wh = torch.log(g_wh) / variances[1]
    # return target for smooth_l1_loss
    return torch.cat([g_cxcy, g_wh], 1)  # [num_priors,4]
def encode_landm(matched, priors, variances):
    """Encode matched ground-truth landmarks as offsets from their priors.

    Args:
        matched: (tensor) Landmark coords of the ground truth matched to each
            prior, five (x, y) pairs per row.
            Shape: [num_priors, 10].
        priors: (tensor) Prior boxes in center-offset form
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        encoded landm (tensor), Shape: [num_priors, 10]
    """
    count = matched.size(0)
    points = torch.reshape(matched, (count, 5, 2))
    # repeat every prior once per landmark point: [num_priors, 5, 4]
    tiled = priors[:, :4].unsqueeze(1).expand(count, 5, 4)
    # dist b/t landmark and prior's center, scaled by variance * prior size
    offsets = points[:, :, :2] - tiled[:, :, :2]
    offsets = offsets / (variances[0] * tiled[:, :, 2:])
    # return target for smooth_l1_loss
    return offsets.reshape(offsets.size(0), -1)


# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
    """Turn predicted offsets back into point-form boxes, undoing the
    encoding applied for offset regression at train time.

    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors,4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions as (xmin, ymin, xmax, ymax)
    """
    centers = priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:]
    sizes = priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])
    top_left = centers - sizes / 2
    bottom_right = sizes + top_left
    return torch.cat((top_left, bottom_right), 1)
def decode_landm(pre, priors, variances):
    """Decode landm from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        pre (tensor): landm predictions for loc layers,
            Shape: [num_priors,10]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded landm predictions, Shape: [num_priors,10]
    """
    centers, sizes = priors[:, :2], priors[:, 2:]
    # one (x, y) pair per landmark point, five points in total
    return torch.cat(
        [centers + pre[:, k:k + 2] * variances[0] * sizes for k in range(0, 10, 2)],
        dim=1,
    )


def log_sum_exp(x):
    """Compute log(sum(exp(x))) over dim 1, keeping that dim.

    The global maximum of x is subtracted before exponentiating for numerical
    stability. Used for the unaveraged confidence loss across all examples
    in a batch.
    Args:
        x (Variable(tensor)): conf_preds from conf layers
    """
    x_max = x.data.max()
    return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max


# Original author: Francisco Massa:
# https://github.com/fmassa/object-detection.torch
# Ported to PyTorch by Max deGroot (02/01/2017)
def nms(boxes, scores, overlap=0.5, top_k=200):
    """Apply non-maximum suppression at test time to avoid detecting too many
    overlapping bounding boxes for a given object.
    Args:
        boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
        scores: (tensor) The class predscores for the img, Shape:[num_priors].
        overlap: (float) The overlap thresh for suppressing unnecessary boxes.
        top_k: (int) The Maximum number of box preds to consider.
    Return:
        (keep, count): `keep` is a long tensor of length num_priors whose first
        `count` entries are the indices of the kept boxes, best score first.
    """

    keep = torch.zeros(scores.size(0), dtype=torch.long)
    if boxes.numel() == 0:
        # same (keep, count) shape as the normal path so callers can unpack
        return keep, 0
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1)
    _, idx = scores.sort(0)  # sort in ascending order
    idx = idx[-top_k:]  # indices of the top-k largest vals

    count = 0
    while idx.numel() > 0:
        i = idx[-1]  # index of current largest val
        keep[count] = i
        count += 1
        if idx.size(0) == 1:
            break
        idx = idx[:-1]  # remove kept element from view
        # intersection of box i with every remaining candidate
        xx1 = torch.clamp(torch.index_select(x1, 0, idx), min=float(x1[i]))
        yy1 = torch.clamp(torch.index_select(y1, 0, idx), min=float(y1[i]))
        xx2 = torch.clamp(torch.index_select(x2, 0, idx), max=float(x2[i]))
        yy2 = torch.clamp(torch.index_select(y2, 0, idx), max=float(y2[i]))
        w = torch.clamp(xx2 - xx1, min=0.0)
        h = torch.clamp(yy2 - yy1, min=0.0)
        inter = w*h
        # IoU = i / (area(a) + area(b) - i)
        rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
        union = (rem_areas - inter) + area[i]
        IoU = inter/union
        # keep only elements with an IoU <= overlap
        idx = idx[IoU.le(overlap)]
    return keep, count


def py_cpu_nms(dets, thresh):
    """Pure Python NMS baseline.

    Args:
        dets: (ndarray) rows of (x1, y1, x2, y2, score).
        thresh: (float) boxes whose overlap with a kept box exceeds this are
            dropped.
    Return:
        list of kept row indices, highest score first.
    """
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    # widths and heights are computed with a +1 term (inclusive pixel extents)
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[rest] - inter)

        inds = np.where(ovr <= thresh)[0]
        # +1 because `ovr` is indexed relative to order[1:]
        order = order[inds + 1]

    return keep
b/face_vid2vid/GPEN/retinaface/utils/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..e4b3b8098a5ad41f8d18d42b6b2fedb694aa5508 --- /dev/null +++ b/face_vid2vid/GPEN/retinaface/utils/timer.py @@ -0,0 +1,40 @@ +# -------------------------------------------------------- +# Fast R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ross Girshick +# -------------------------------------------------------- + +import time + + +class Timer(object): + """A simple timer.""" + def __init__(self): + self.total_time = 0. + self.calls = 0 + self.start_time = 0. + self.diff = 0. + self.average_time = 0. + + def tic(self): + # using time.time instead of time.clock because time time.clock + # does not normalize for multithreading + self.start_time = time.time() + + def toc(self, average=True): + self.diff = time.time() - self.start_time + self.total_time += self.diff + self.calls += 1 + self.average_time = self.total_time / self.calls + if average: + return self.average_time + else: + return self.diff + + def clear(self): + self.total_time = 0. + self.calls = 0 + self.start_time = 0. + self.diff = 0. + self.average_time = 0. diff --git a/face_vid2vid/GPEN/sr_model/arch_util.py b/face_vid2vid/GPEN/sr_model/arch_util.py new file mode 100644 index 0000000000000000000000000000000000000000..ce5b9d92f418d3f8b5b8887a24491f65660b33f9 --- /dev/null +++ b/face_vid2vid/GPEN/sr_model/arch_util.py @@ -0,0 +1,125 @@ +import math +import torch +from torch import nn as nn +from torch.nn import functional as F +from torch.nn import init as init +from torch.nn.modules.batchnorm import _BatchNorm + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. 
+ bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +class ResidualBlockNoBN(nn.Module): + """Residual block without BN. + + It has a style of: + ---Conv-ReLU-Conv-+- + |________________| + + Args: + num_feat (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Residual scale. Default: 1. + pytorch_init (bool): If set to True, use pytorch default init, + otherwise, use default_init_weights. Default: False. 
+ """ + + def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): + super(ResidualBlockNoBN, self).__init__() + self.res_scale = res_scale + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU(inplace=True) + + if not pytorch_init: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' + 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. 
+ """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) \ No newline at end of file diff --git a/face_vid2vid/GPEN/sr_model/real_esrnet.py b/face_vid2vid/GPEN/sr_model/real_esrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e2af56f38b7b227eb89d6e7337ad094336e3684e --- /dev/null +++ b/face_vid2vid/GPEN/sr_model/real_esrnet.py @@ -0,0 +1,107 @@ +import os +import torch +import numpy as np +from rrdbnet_arch import RRDBNet +from torch.nn import functional as F +import torch + + +class RealESRNet(object): + def __init__(self, base_dir=os.path.dirname(__file__), model=None, scale=2): + self.base_dir = base_dir + self.scale = scale + self.load_srmodel(base_dir, model) + self.srmodel_trt = None + + def load_srmodel(self, base_dir, model): + self.scale = 2 if "x2" in model else 4 if "x4" in model else -1 + if self.scale == -1: + raise Exception("Scale not supported") + self.srmodel = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=32, num_block=23, num_grow_ch=32, scale=self.scale) + if model is None: + loadnet = torch.load(os.path.join(self.base_dir, 'weights', 'realesrnet_x2.pth')) + else: + loadnet = torch.load(os.path.join(self.base_dir, 'weights', model+'.pth')) + self.srmodel.load_state_dict(loadnet['params_ema'], strict=True) + self.srmodel.eval() + self.srmodel = self.srmodel.cuda() + + def build_trt(self, img): + img = img.astype(np.float32) / 255. + img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() + img = img.unsqueeze(0).cuda() + print('building trt model srmodel') + from torch2trt import torch2trt + self.srmodel_trt = torch2trt(self.srmodel, [img], fp16_mode=True) + print('sucessfully built') + del self.srmodel + + def process_trt(self, img): + img = img.astype(np.float32) / 255. 
+ img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() + img = img.unsqueeze(0).cuda() + + if self.scale == 2: + mod_scale = 2 + elif self.scale == 1: + mod_scale = 4 + else: + mod_scale = None + if mod_scale is not None: + h_pad, w_pad = 0, 0 + _, _, h, w = img.size() + if (h % mod_scale != 0): + h_pad = (mod_scale - h % mod_scale) + if (w % mod_scale != 0): + w_pad = (mod_scale - w % mod_scale) + img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect') + try: + with torch.no_grad(): + output = self.srmodel_trt(img) + # remove extra pad + if mod_scale is not None: + _, _, h, w = output.size() + output = output[:, :, 0:h - h_pad, 0:w - w_pad] + output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) + output = (output * 255.0).round().astype(np.uint8) + + return output + except: + return None + + def process(self, img): + img = img.astype(np.float32) / 255. + img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() + img = img.unsqueeze(0).cuda() + # print(img.shape) + + if self.scale == 2: + mod_scale = 2 + elif self.scale == 1: + mod_scale = 4 + else: + mod_scale = None + if mod_scale is not None: + h_pad, w_pad = 0, 0 + _, _, h, w = img.size() + if (h % mod_scale != 0): + h_pad = (mod_scale - h % mod_scale) + if (w % mod_scale != 0): + w_pad = (mod_scale - w % mod_scale) + img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect') + try: + with torch.no_grad(): + output = self.srmodel(img) + # remove extra pad + if mod_scale is not None: + _, _, h, w = output.size() + output = output[:, :, 0:h - h_pad, 0:w - w_pad] + output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) + output = (output * 255.0).round().astype(np.uint8) + + return output + except: + return None + diff --git a/face_vid2vid/GPEN/sr_model/rrdbnet_arch.py b/face_vid2vid/GPEN/sr_model/rrdbnet_arch.py new file mode 100644 index 
0000000000000000000000000000000000000000..5e1f04c5aee5bcdcd2ddae5471843ff057d863b4 --- /dev/null +++ b/face_vid2vid/GPEN/sr_model/rrdbnet_arch.py @@ -0,0 +1,116 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from arch_util import default_init_weights, make_layer, pixel_unshuffle + + +class ResidualDenseBlock(nn.Module): + """Residual Dense Block. + + Used in RRDB block in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat=64, num_grow_ch=32): + super(ResidualDenseBlock, self).__init__() + self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) + self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + # Emperically, we use 0.2 to scale the residual for better performance + return x5 * 0.2 + x + + +class RRDB(nn.Module): + """Residual in Residual Dense Block. + + Used in RRDB-Net in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. 
+ """ + + def __init__(self, num_feat, num_grow_ch=32): + super(RRDB, self).__init__() + self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) + + def forward(self, x): + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + # Emperically, we use 0.2 to scale the residual for better performance + return out * 0.2 + x + +class RRDBNet(nn.Module): + """Networks consisting of Residual in Residual Dense Block, which is used + in ESRGAN. + + ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. + + We extend ESRGAN for scale x2 and scale x1. + Note: This is one option for scale 1, scale 2 in RRDBNet. + We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size + and enlarge the channel size before feeding inputs into the main ESRGAN architecture. + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64 + num_block (int): Block number in the trunk network. Defaults: 23 + num_grow_ch (int): Channels for each growth. Default: 32. 
+ """ + + def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): + super(RRDBNet, self).__init__() + self.scale = scale + if scale == 2: + num_in_ch = num_in_ch * 4 + elif scale == 1: + num_in_ch = num_in_ch * 16 + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) + self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + # upsample + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + if self.scale == 2: + feat = pixel_unshuffle(x, scale=2) + elif self.scale == 1: + feat = pixel_unshuffle(x, scale=4) + else: + feat = x + feat = self.conv_first(feat) + body_feat = self.conv_body(self.body(feat)) + feat = feat + body_feat + # upsample + feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) + feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.conv_hr(feat))) + return out diff --git a/face_vid2vid/LICENSE.md b/face_vid2vid/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..c8a17071020f90b31b40687e43180caa17fe4c37 --- /dev/null +++ b/face_vid2vid/LICENSE.md @@ -0,0 +1,161 @@ +## creative commons + +# Attribution-NonCommercial 4.0 International + +Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. 
Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. + +### Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. + +* __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). + +* __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. 
Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). + +## Creative Commons Attribution-NonCommercial 4.0 International Public License + +By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. + +### Section 1 – Definitions. + +a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. + +b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. + +c. 
__Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. + +d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. + +e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. + +f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. + +g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. + +h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. + +i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. + +j. 
__Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. + +k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. + +l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. + +### Section 2 – Scope. + +a. ___License grant.___ + + 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: + + A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and + + B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. + + 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. + + 3. __Term.__ The term of this Public License is specified in Section 6(a). + + 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. 
The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. + + 5. __Downstream recipients.__ + + A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. + + B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. + + 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). + +b. ___Other rights.___ + + 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this Public License. + + 3. 
To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. + +### Section 3 – License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the following conditions. + +a. ___Attribution.___ + + 1. If You Share the Licensed Material (including in modified form), You must: + + A. retain the following if it is supplied by the Licensor with the Licensed Material: + + i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of warranties; + + v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; + + B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and + + C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. + + 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. + + 4. 
If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License. + +### Section 4 – Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: + +a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; + +b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and + +c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. + +### Section 5 – Disclaimer of Warranties and Limitation of Liability. + +a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ + +b. 
__To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ + +c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. + +### Section 6 – Term and Termination. + +a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. + +b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. + +c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. + +d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. + +### Section 7 – Other Terms and Conditions. + +a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. + +b. 
Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. + +### Section 8 – Interpretation. + +a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. + +b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. + +c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. + +d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. + +> Creative Commons is not a party to its public licenses. 
Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. +> +> Creative Commons may be contacted at creativecommons.org diff --git a/face_vid2vid/README.md b/face_vid2vid/README.md new file mode 100644 index 0000000000000000000000000000000000000000..aa1d6b6b9dceb7680b205d45c6dc35a108d6ac49 --- /dev/null +++ b/face_vid2vid/README.md @@ -0,0 +1,64 @@ +# One-Shot Free-View Neural Talking Head Synthesis +Unofficial pytorch implementation of paper "One-Shot Free-View Neural Talking-Head Synthesis for Video Conferencing". + +```Python 3.6``` and ```Pytorch 1.7``` are used. + + +Updates: +-------- +```2021.11.05``` : +* Replace Jacobian with the rotation matrix (Assuming J = R) to avoid estimating Jacobian. +* Correct the rotation matrix. + +```2021.11.17``` : +* Better Generator, better performance (models and checkpoints have been released). 
+ +Driving | Beta Version | FOMM | New Version: + + +https://user-images.githubusercontent.com/17874285/142828000-db7b324e-c2fd-4fdc-a272-04fb8adbc88a.mp4 + + +-------- +Driving | FOMM | Ours: +![show](https://github.com/zhanglonghao1992/ReadmeImages/blob/master/images/081.gif) + +Free-View: +![show](https://github.com/zhanglonghao1992/ReadmeImages/blob/master/images/concat.gif) + +Train: +-------- +``` +python run.py --config config/vox-256.yaml --device_ids 0,1,2,3,4,5,6,7 +``` + +Demo: +-------- +``` +python demo.py --config config/vox-256.yaml --checkpoint path/to/checkpoint --source_image path/to/source --driving_video path/to/driving --relative --adapt_scale --find_best_frame +``` +free-view (e.g. yaw=20, pitch=roll=0): +``` +python demo.py --config config/vox-256.yaml --checkpoint path/to/checkpoint --source_image path/to/source --driving_video path/to/driving --relative --adapt_scale --find_best_frame --free_view --yaw 20 --pitch 0 --roll 0 +``` +Note: run ```crop-video.py --inp driving_video.mp4``` first to get the cropping suggestion and crop the raw video. + +Pretrained Model: +-------- + + Model | Train Set | Baidu Netdisk | Media Fire | + ------- |------------ |----------- |-------- | + Vox-256-Beta| VoxCeleb-v1 | [Baidu](https://pan.baidu.com/s/1lLS4ArbK2yWelsL-EtwU8g) (PW: c0tc)| [MF](https://www.mediafire.com/folder/rw51an7tk7bh2/TalkingHead) | + Vox-256-New | VoxCeleb-v1 | - | [MF](https://www.mediafire.com/folder/fcvtkn21j57bb/TalkingHead_Update) | + Vox-512 | VoxCeleb-v2 | soon | soon | + + Note: + 1. For now, the Beta Version is not well tuned. + 2. For free-view synthesis, it is recommended that Yaw, Pitch and Roll are within ±45°, ±20° and ±20° respectively. + 3. Face Restoration algorithms ([GPEN](https://github.com/yangxy/GPEN)) can be used for post-processing to significantly improve the resolution. 
+![show](https://github.com/zhanglonghao1992/ReadmeImages/blob/master/images/s%20r.gif) + + +Acknowlegement: +-------- +Thanks to [NV](https://github.com/NVlabs/face-vid2vid), [AliaksandrSiarohin](https://github.com/AliaksandrSiarohin/first-order-model) and [DeepHeadPose](https://github.com/DriverDistraction/DeepHeadPose). diff --git a/face_vid2vid/__pycache__/animate.cpython-310.pyc b/face_vid2vid/__pycache__/animate.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..714f46038801eb77fdc2d7500d660749fc143ca0 Binary files /dev/null and b/face_vid2vid/__pycache__/animate.cpython-310.pyc differ diff --git a/face_vid2vid/animate.py b/face_vid2vid/animate.py new file mode 100644 index 0000000000000000000000000000000000000000..53429e699dc212f324676c60297893e45a72e773 --- /dev/null +++ b/face_vid2vid/animate.py @@ -0,0 +1,33 @@ +import os +from tqdm import tqdm + +import torch +from torch.utils.data import DataLoader + +import imageio +from scipy.spatial import ConvexHull +import numpy as np + +from face_vid2vid.sync_batchnorm.replicate import DataParallelWithCallback + +def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False, + use_relative_movement=False, use_relative_jacobian=False): + if adapt_movement_scale: + source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume + driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume + adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) + else: + adapt_movement_scale = 1 + + kp_new = {k: v for k, v in kp_driving.items()} + + if use_relative_movement: + kp_value_diff = (kp_driving['value'] - kp_driving_initial['value']) + kp_value_diff *= adapt_movement_scale + kp_new['value'] = kp_value_diff + kp_source['value'] + + if use_relative_jacobian: + jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian'])) + kp_new['jacobian'] = torch.matmul(jacobian_diff, 
def resize_clip(clip, size, interpolation='bilinear'):
    """Resize every frame of a clip to a common size.

    Args:
        clip: Non-empty sequence of frames. The type of ``clip[0]``
            (``numpy.ndarray`` or ``PIL.Image.Image``) selects the backend
            used for all frames.
        size: Either a number, interpreted as the target length of the
            shorter spatial side (the other side is derived by
            ``get_resize_sizes`` so the aspect ratio is kept), or a
            two-element sequence whose elements are swapped before being
            handed to the backend.
        interpolation: ``'bilinear'`` selects bilinear interpolation; any
            other value selects nearest-neighbour.

    Returns:
        A list of resized frames, or ``clip`` itself (same object) when
        ``size`` is a number and the shorter side already has that length.

    Raises:
        TypeError: If ``clip[0]`` is neither an ndarray nor a PIL image.
    """
    first = clip[0]
    is_array = isinstance(first, np.ndarray)
    if not is_array and not isinstance(first, PIL.Image.Image):
        raise TypeError('Expected numpy.ndarray or PIL.Image' +
                        'but got list of {0}'.format(type(first)))

    bilinear = interpolation == 'bilinear'

    if isinstance(size, numbers.Number):
        if is_array:
            im_h, im_w = first.shape[:2]
        else:
            im_w, im_h = first.size
        # Shorter side already matches the requested size: nothing to do.
        if min(im_h, im_w) == size:
            return clip
        new_h, new_w = get_resize_sizes(im_h, im_w, size)
        # skimage's resize takes (rows, cols) whereas PIL's takes
        # (width, height).  The ndarray path used to pass (new_w, new_h),
        # which transposed the output shape of non-square frames.
        target = (new_h, new_w) if is_array else (new_w, new_h)
    else:
        # NOTE(review): the swap is kept exactly as before.  It means a
        # sequence is read as (width, height) on the ndarray path but as
        # (height, width) on the PIL path -- confirm which convention callers
        # rely on before unifying.
        target = (size[1], size[0])

    if is_array:
        return [
            resize(img, target, order=1 if bilinear else 0, preserve_range=True,
                   mode='constant', anti_aliasing=True)
            for img in clip
        ]

    # The two filters used to be swapped: 'bilinear' selected NEAREST.
    pil_inter = PIL.Image.BILINEAR if bilinear else PIL.Image.NEAREST
    return [img.resize(target, pil_inter) for img in clip]
class RandomCrop(object):
    """Extract a random crop at the same location from every frame of a clip.

    Args:
        size (sequence or int): Desired output size of the crop in format
            (h, w). A single number ``n`` is expanded to ``(n, n)``.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            size = (size, size)

        self.size = size

    def __call__(self, clip):
        """Pad the clip if needed, then crop all frames at one random offset.

        Args:
            clip: Sequence of frames; ``clip[0]`` must be a numpy.ndarray or
                a PIL.Image to pass the type check.
                NOTE(review): ``pad_clip`` (defined elsewhere in this file)
                reads ``clip[0].shape``, so PIL input looks unsupported in
                practice -- confirm.

        Returns:
            The value returned by ``crop_clip`` for the chosen window.

        Raises:
            TypeError: If ``clip[0]`` is neither an ndarray nor a PIL image.
        """
        h, w = self.size
        if not isinstance(clip[0], (np.ndarray, PIL.Image.Image)):
            raise TypeError('Expected numpy.ndarray or PIL.Image' +
                            'but got list of {0}'.format(type(clip[0])))

        # Pad first so the crop window is never larger than the frame.
        clip = pad_clip(clip, h, w)
        im_h, im_w = clip.shape[1:3]

        # The previous guards were swapped (x1 tested the height, y1 tested
        # the width), which pinned an offset to 0 whenever the *other*
        # dimension already fit.  randint(0, 0) is 0, so no guard is needed.
        x1 = random.randint(0, im_w - w)
        y1 = random.randint(0, im_h - h)

        return crop_clip(clip, y1, x1, h, w)
class ColorJitter(object):
    """Randomly change the brightness, contrast, saturation and hue of a clip.

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue (float): How much to jitter hue. hue_factor is chosen uniformly
            from [-hue, hue]. Should be >= 0 and <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    def get_params(self, brightness, contrast, saturation, hue):
        """Draw one random factor per enabled adjustment.

        Returns:
            Tuple ``(brightness_factor, contrast_factor, saturation_factor,
            hue_factor)``; an entry is ``None`` when its strength is not
            greater than zero.
        """
        if brightness > 0:
            brightness_factor = random.uniform(
                max(0, 1 - brightness), 1 + brightness)
        else:
            brightness_factor = None

        if contrast > 0:
            contrast_factor = random.uniform(
                max(0, 1 - contrast), 1 + contrast)
        else:
            contrast_factor = None

        if saturation > 0:
            saturation_factor = random.uniform(
                max(0, 1 - saturation), 1 + saturation)
        else:
            saturation_factor = None

        if hue > 0:
            hue_factor = random.uniform(-hue, hue)
        else:
            hue_factor = None
        return brightness_factor, contrast_factor, saturation_factor, hue_factor

    def _build_transforms(self):
        """Return the enabled PIL-level adjustments in random order."""
        brightness, contrast, saturation, hue = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue)

        img_transforms = []
        if brightness is not None:
            img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
        if saturation is not None:
            img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
        if hue is not None:
            img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
        if contrast is not None:
            img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
        random.shuffle(img_transforms)
        return img_transforms

    def __call__(self, clip):
        """Apply one randomly drawn jitter to every frame of the clip.

        Args:
            clip (list): Frames, all numpy.ndarray or all PIL.Image; the type
                of ``clip[0]`` selects the code path.

        Returns:
            list: Jittered frames. ndarray input yields float32 arrays, PIL
            input yields PIL images.

        Raises:
            TypeError: If ``clip[0]`` is neither an ndarray nor a PIL image.
        """
        if isinstance(clip[0], np.ndarray):
            # Arrays are converted to PIL for the adjustments and back again.
            img_transforms = ([img_as_ubyte, torchvision.transforms.ToPILImage()]
                              + self._build_transforms()
                              + [np.array, img_as_float])

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                jittered_clip = []
                for img in clip:
                    jittered_img = img
                    for func in img_transforms:
                        jittered_img = func(jittered_img)
                    jittered_clip.append(jittered_img.astype('float32'))
        elif isinstance(clip[0], PIL.Image.Image):
            img_transforms = self._build_transforms()

            jittered_clip = []
            for img in clip:
                # Chain the adjustments.  Previously each one was applied to
                # the untouched ``img`` so only the last survived, and an
                # empty transform list left ``jittered_img`` unbound.
                jittered_img = img
                for func in img_transforms:
                    jittered_img = func(jittered_img)
                jittered_clip.append(jittered_img)
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image' +
                            'but got list of {0}'.format(type(clip[0])))
        return jittered_clip
0000000000000000000000000000000000000000..e075a4ba2e8081f05c43b7ea9d934bc66239c5f7 --- /dev/null +++ b/face_vid2vid/config/vox-256-spade.yaml @@ -0,0 +1,88 @@ +dataset_params: + root_dir: + frame_shape: [256, 256, 3] + id_sampling: True + pairs_list: None + augmentation_params: + flip_param: + horizontal_flip: True + time_flip: True + jitter_param: + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + hue: 0.1 + + +model_params: + common_params: + num_kp: 15 + image_channel: 3 + feature_channel: 32 + estimate_jacobian: False + kp_detector_params: + temperature: 0.1 + block_expansion: 32 + max_features: 1024 + scale_factor: 0.25 + num_blocks: 5 + reshape_channel: 16384 # 16384 = 1024 * 16 + reshape_depth: 16 + he_estimator_params: + block_expansion: 64 + max_features: 2048 + num_bins: 66 + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 2 + reshape_channel: 32 + reshape_depth: 16 # 512 = 32 * 16 + num_resblocks: 6 + estimate_occlusion_map: True + dense_motion_params: + block_expansion: 32 + max_features: 1024 + num_blocks: 5 + # reshape_channel: 32 + reshape_depth: 16 + compress: 4 + discriminator_params: + scales: [1] + block_expansion: 32 + max_features: 512 + num_blocks: 4 + sn: True + +train_params: + num_epochs: 200 + num_repeats: 75 + epoch_milestones: [180,] + lr_generator: 2.0e-4 + lr_discriminator: 2.0e-4 + lr_kp_detector: 2.0e-4 + lr_he_estimator: 2.0e-4 + gan_mode: 'hinge' # hinge or ls + batch_size: 1 + scales: [1, 0.5, 0.25, 0.125] + checkpoint_freq: 60 + hopenet_snapshot: './checkpoints/hopenet_robust_alpha1.pkl' + transform_params: + sigma_affine: 0.05 + sigma_tps: 0.005 + points_tps: 5 + loss_weights: + generator_gan: 1 + discriminator_gan: 1 + feature_matching: [10, 10, 10, 10] + perceptual: [10, 10, 10, 10, 10] + equivariance_value: 10 + equivariance_jacobian: 0 + keypoint: 10 + headpose: 20 + expression: 5 + +visualizer_params: + kp_size: 5 + draw_border: True + colormap: 'gist_rainbow' diff --git 
def compute_bbox(start, end, fps, tube_bbox, frame_shape, inp, image_shape, increase_area=0.1):
    """Build the ffmpeg command that crops and rescales one face track.

    Args:
        start: Index of the first frame of the track.
        end: Index of the last frame of the track.
        fps: Frame rate used to convert frame indices to seconds.
        tube_bbox: ``(left, top, right, bottom)`` box enclosing the track.
        frame_shape: Shape of a video frame; only ``[0]`` (height) and
            ``[1]`` (width) are read, to clamp the enlarged box.
        inp: Path of the input video.
        image_shape: Two values joined as ``a:b`` for ffmpeg's ``scale``
            filter.
        increase_area: Minimum relative margin added on every side.

    Returns:
        The ffmpeg command line as a single string.
    """
    # Local import: the module's import block is not touched by this change.
    import shlex

    left, top, right, bot = tube_bbox
    width = right - left
    height = bot - top

    # Grow the shorter side more, so the enlarged box approaches a square
    # before it is clamped to the frame.
    width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
    height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))

    left = int(left - width_increase * width)
    top = int(top - height_increase * height)
    right = int(right + width_increase * width)
    bot = int(bot + height_increase * height)

    # Clamp the enlarged box to the frame.
    top, bot = max(0, top), min(bot, frame_shape[0])
    left, right = max(0, left), min(right, frame_shape[1])
    h, w = bot - top, right - left

    # Frame indices -> seconds.
    start = start / fps
    end = end / fps
    time = end - start

    scale = f'{image_shape[0]}:{image_shape[1]}'

    # Quote the path so names with spaces or shell metacharacters survive
    # being pasted into a POSIX shell; plain names are left unchanged.
    quoted_inp = shlex.quote(str(inp))

    return f'ffmpeg -i {quoted_inp} -ss {start} -t {time} -filter:v "crop={w}:{h}:{left}:{top}, scale={scale}" crop.mp4'
bboxes: + intersection = 0 + current_trajectory = None + for trajectory in trajectories: + tube_bbox = trajectory[0] + current_intersection = bb_intersection_over_union(tube_bbox, bbox) + if intersection < current_intersection and current_intersection > args.iou_with_initial: + intersection = bb_intersection_over_union(tube_bbox, bbox) + current_trajectory = trajectory + + ## Create new trajectory + if current_trajectory is None: + trajectories.append([bbox, bbox, i, i]) + else: + current_trajectory[3] = i + current_trajectory[1] = join(current_trajectory[1], bbox) + + + except IndexError as e: + raise (e) + + commands += compute_bbox_trajectories(trajectories, fps, frame_shape, args) + return commands + + +if __name__ == "__main__": + parser = ArgumentParser() + + parser.add_argument("--image_shape", default=(256, 256), type=lambda x: tuple(map(int, x.split(','))), + help="Image shape") + parser.add_argument("--increase", default=0.1, type=float, help='Increase bbox by this amount') + parser.add_argument("--iou_with_initial", type=float, default=0.25, help="The minimal allowed iou with inital bbox") + parser.add_argument("--inp", required=True, help='Input image or video') + parser.add_argument("--min_frames", type=int, default=150, help='Minimum number of frames') + parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.") + + + args = parser.parse_args() + + commands = process_video(args) + for command in commands: + print (command) + + \ No newline at end of file diff --git a/face_vid2vid/demo.py b/face_vid2vid/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..3ecd3a8f4a5696fb916efe70dad256bedfcc95ab --- /dev/null +++ b/face_vid2vid/demo.py @@ -0,0 +1,304 @@ +import matplotlib +matplotlib.use('Agg') +import os, sys +import yaml +from argparse import ArgumentParser +from tqdm import tqdm + +import imageio +import numpy as np +from skimage.transform import resize +from skimage import img_as_ubyte +import torch 
def load_checkpoints(config_path, checkpoint_path, gen, cpu=False):
    """Build the three networks from a config and load their weights.

    Args:
        config_path: Path to the YAML config; its ``model_params`` section
            supplies the constructor arguments.
        checkpoint_path: Path to a checkpoint holding ``'generator'``,
            ``'kp_detector'`` and ``'he_estimator'`` state dicts.
        gen: ``'original'`` for ``OcclusionAwareGenerator`` or ``'spade'``
            for ``OcclusionAwareSPADEGenerator``.
        cpu: When true, keep everything on the CPU; otherwise move the
            models to CUDA and wrap them in ``DataParallelWithCallback``.

    Returns:
        Tuple ``(generator, kp_detector, he_estimator)``, all in eval mode.

    Raises:
        ValueError: If ``gen`` is not one of the two supported names.
    """
    with open(config_path) as f:
        # yaml.load() without an explicit Loader is rejected by PyYAML >= 6
        # and can construct arbitrary objects on older versions; the config
        # only needs plain scalars, lists and mappings.
        config = yaml.safe_load(f)

    model_params = config['model_params']
    common_params = model_params['common_params']

    if gen == 'original':
        generator_cls = OcclusionAwareGenerator
    elif gen == 'spade':
        generator_cls = OcclusionAwareSPADEGenerator
    else:
        # Previously an unknown name fell through and failed later with an
        # unbound-variable error.
        raise ValueError("gen must be 'original' or 'spade', got {0!r}".format(gen))

    generator = generator_cls(**model_params['generator_params'], **common_params)
    kp_detector = KPDetector(**model_params['kp_detector_params'], **common_params)
    he_estimator = HEEstimator(**model_params['he_estimator_params'], **common_params)

    if not cpu:
        generator.cuda()
        kp_detector.cuda()
        he_estimator.cuda()

    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    if cpu:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint_path)

    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])
    he_estimator.load_state_dict(checkpoint['he_estimator'])

    if not cpu:
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)
        he_estimator = DataParallelWithCallback(he_estimator)

    generator.eval()
    kp_detector.eval()
    he_estimator.eval()

    return generator, kp_detector, he_estimator
F.softmax(pred) + degree = torch.sum(pred*idx_tensor, axis=1) * 3 - 99 + + return degree + +''' +# beta version +def get_rotation_matrix(yaw, pitch, roll): + yaw = yaw / 180 * 3.14 + pitch = pitch / 180 * 3.14 + roll = roll / 180 * 3.14 + + roll = roll.unsqueeze(1) + pitch = pitch.unsqueeze(1) + yaw = yaw.unsqueeze(1) + + roll_mat = torch.cat([torch.ones_like(roll), torch.zeros_like(roll), torch.zeros_like(roll), + torch.zeros_like(roll), torch.cos(roll), -torch.sin(roll), + torch.zeros_like(roll), torch.sin(roll), torch.cos(roll)], dim=1) + roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3) + + pitch_mat = torch.cat([torch.cos(pitch), torch.zeros_like(pitch), torch.sin(pitch), + torch.zeros_like(pitch), torch.ones_like(pitch), torch.zeros_like(pitch), + -torch.sin(pitch), torch.zeros_like(pitch), torch.cos(pitch)], dim=1) + pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3) + + yaw_mat = torch.cat([torch.cos(yaw), -torch.sin(yaw), torch.zeros_like(yaw), + torch.sin(yaw), torch.cos(yaw), torch.zeros_like(yaw), + torch.zeros_like(yaw), torch.zeros_like(yaw), torch.ones_like(yaw)], dim=1) + yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3) + + rot_mat = torch.einsum('bij,bjk,bkm->bim', roll_mat, pitch_mat, yaw_mat) + + return rot_mat + +''' +def get_rotation_matrix(yaw, pitch, roll): + yaw = yaw / 180 * 3.14 + pitch = pitch / 180 * 3.14 + roll = roll / 180 * 3.14 + + roll = roll.unsqueeze(1) + pitch = pitch.unsqueeze(1) + yaw = yaw.unsqueeze(1) + + pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch), + torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch), + torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1) + pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3) + + yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw), + torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw), + -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1) + yaw_mat = 
def keypoint_transformation(kp_canonical, he, estimate_jacobian=True, free_view=False, yaw=0, pitch=0, roll=0):
    """Rotate, translate and deform canonical keypoints with a head pose.

    Args:
        kp_canonical: Dict with ``'value'`` (indexed as batch x keypoint x 3
            by the einsum below) and, when ``estimate_jacobian`` is true,
            ``'jacobian'``.
        he: Dict with head-pose outputs ``'yaw'``, ``'pitch'``, ``'roll'``
            (fed to ``headpose_pred_to_degree``), translation ``'t'`` and
            expression offsets ``'exp'``.
        estimate_jacobian: Whether to rotate and return the jacobians too.
        free_view: When true, any of ``yaw``/``pitch``/``roll`` that is not
            ``None`` overrides the corresponding predicted angle.
        yaw, pitch, roll: Override angles, used only when ``free_view`` is
            true; they are passed on to ``get_rotation_matrix``.

    Returns:
        Dict with ``'value'`` (transformed keypoints) and ``'jacobian'``
        (rotated jacobians, or ``None`` when ``estimate_jacobian`` is false).
    """
    kp = kp_canonical['value']

    def _angle(name, override):
        # Use the caller's override in free-view mode; otherwise decode the
        # network prediction.  The override is created on the keypoints'
        # device instead of via a hard-coded .cuda(), which broke CPU mode.
        if free_view and override is not None:
            return torch.tensor([override], device=kp.device)
        return headpose_pred_to_degree(he[name])

    yaw = _angle('yaw', yaw)
    pitch = _angle('pitch', pitch)
    roll = _angle('roll', roll)

    t, exp = he['t'], he['exp']

    rot_mat = get_rotation_matrix(yaw, pitch, roll)

    # keypoint rotation
    kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp)

    # keypoint translation -- out-of-place unsqueeze: the in-place variant
    # reshaped he['t'] itself, so reusing the same `he` failed on the
    # second call.
    t = t.unsqueeze(1).repeat(1, kp.shape[1], 1)
    kp_t = kp_rotated + t

    # add expression deviation
    exp = exp.view(exp.shape[0], -1, 3)
    kp_transformed = kp_t + exp

    if estimate_jacobian:
        jacobian = kp_canonical['jacobian']
        jacobian_transformed = torch.einsum('bmp,bkps->bkms', rot_mat, jacobian)
    else:
        jacobian_transformed = None

    return {'value': kp_transformed, 'jacobian': jacobian_transformed}
def find_best_frame(source, driving, cpu=False):
    """Return the index of the driving frame whose face best matches the source.

    Landmarks of the source and of each driving frame are centred and
    scaled by the square root of their convex-hull area; the frame with the
    smallest sum of squared differences wins.

    Args:
        source: Source image; it is multiplied by 255 before detection.
        driving: Iterable of driving frames, scaled the same way.
        cpu: Run the landmark detector on the CPU instead of CUDA.

    Returns:
        Index of the best-matching frame (0 if no frame yields landmarks).

    Raises:
        ValueError: If no face is found in the source image.
    """
    import face_alignment

    def _normalize_landmarks(kp):
        # Named differently from the module-level normalize_kp import,
        # which the old inner helper shadowed.
        kp = kp - kp.mean(axis=0, keepdims=True)
        area = np.sqrt(ConvexHull(kp[:, :2]).volume)
        kp[:, :2] = kp[:, :2] / area
        return kp

    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
                                      device='cpu' if cpu else 'cuda')

    # NOTE(review): get_landmarks is expected to return None when no face is
    # detected -- confirm against the installed face_alignment version.
    source_landmarks = fa.get_landmarks(255 * source)
    if not source_landmarks:
        raise ValueError('No face detected in the source image')
    kp_source = _normalize_landmarks(source_landmarks[0])

    norm = float('inf')
    frame_num = 0
    for i, image in tqdm(enumerate(driving)):
        driving_landmarks = fa.get_landmarks(255 * image)
        if not driving_landmarks:
            # Frames without a detectable face cannot be compared; skip them
            # instead of crashing on the [0] lookup.
            continue
        kp_driving = _normalize_landmarks(driving_landmarks[0])
        new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
        if new_norm < norm:
            norm = new_norm
            frame_num = i
    return frame_num
(Only for faces, requires face_aligment lib)") + + parser.add_argument("--best_frame", dest="best_frame", type=int, default=None, + help="Set frame to start from.") + + parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.") + + parser.add_argument("--free_view", dest="free_view", action="store_true", help="control head pose") + parser.add_argument("--yaw", dest="yaw", type=int, default=None, help="yaw") + parser.add_argument("--pitch", dest="pitch", type=int, default=None, help="pitch") + parser.add_argument("--roll", dest="roll", type=int, default=None, help="roll") + + + parser.set_defaults(relative=False) + parser.set_defaults(adapt_scale=False) + parser.set_defaults(free_view=False) + + opt = parser.parse_args() + + source_image = imageio.imread(opt.source_image) + reader = imageio.get_reader(opt.driving_video) + fps = reader.get_meta_data()['fps'] + driving_video = [] + try: + for im in reader: + driving_video.append(im) + except RuntimeError: + pass + reader.close() + + source_image = resize(source_image, (256, 256))[..., :3] + driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video] + generator, kp_detector, he_estimator = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, gen=opt.gen, cpu=opt.cpu) + + with open(opt.config) as f: + config = yaml.load(f) + estimate_jacobian = config['model_params']['common_params']['estimate_jacobian'] + print(f'estimate jacobian: {estimate_jacobian}') + + if opt.find_best_frame or opt.best_frame is not None: + i = opt.best_frame if opt.best_frame is not None else find_best_frame(source_image, driving_video, cpu=opt.cpu) + print ("Best frame: " + str(i)) + driving_forward = driving_video[i:] + driving_backward = driving_video[:(i+1)][::-1] + predictions_forward = make_animation(source_image, driving_forward, generator, kp_detector, he_estimator, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, estimate_jacobian=estimate_jacobian, cpu=opt.cpu, 
free_view=opt.free_view, yaw=opt.yaw, pitch=opt.pitch, roll=opt.roll) + predictions_backward = make_animation(source_image, driving_backward, generator, kp_detector, he_estimator, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, estimate_jacobian=estimate_jacobian, cpu=opt.cpu, free_view=opt.free_view, yaw=opt.yaw, pitch=opt.pitch, roll=opt.roll) + predictions = predictions_backward[::-1] + predictions_forward[1:] + else: + predictions = make_animation(source_image, driving_video, generator, kp_detector, he_estimator, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, estimate_jacobian=estimate_jacobian, cpu=opt.cpu, free_view=opt.free_view, yaw=opt.yaw, pitch=opt.pitch, roll=opt.roll) + imageio.mimsave(opt.result_video, [img_as_ubyte(frame) for frame in predictions], fps=fps) diff --git a/face_vid2vid/demo_utils.py b/face_vid2vid/demo_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..243a1bd943a1d7533ffaad20aadc8123de39d061 --- /dev/null +++ b/face_vid2vid/demo_utils.py @@ -0,0 +1,368 @@ +import os +import sys +import cv2 +import yaml +import imageio +import numpy as np +import torch +import torch.nn.functional as F + + +sys.path.append("./face-vid2vid") +from sync_batchnorm import DataParallelWithCallback +from modules.generator import OcclusionAwareSPADEGenerator +from modules.keypoint_detector import KPDetector, HEEstimator +from animate import normalize_kp +from batch_face import RetinaFace + + +if sys.version_info[0] < 3: + raise Exception("You must use Python 3 or higher. 
Recommended version is Python 3.7") + + +def load_checkpoints(config_path, checkpoint_path): + with open(config_path) as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + generator = OcclusionAwareSPADEGenerator(**config["model_params"]["generator_params"], **config["model_params"]["common_params"]) + # convert to half precision to speed up + generator.cuda().half() + + kp_detector = KPDetector(**config["model_params"]["kp_detector_params"], **config["model_params"]["common_params"]) + # the result will be wrong if converted to half precision, not sure why + kp_detector.cuda() # .half() + + he_estimator = HEEstimator(**config["model_params"]["he_estimator_params"], **config["model_params"]["common_params"]) + # the result will be wrong if converted to half precision, not sure why + he_estimator.cuda() # .half() + + print("Loading checkpoints") + checkpoint = torch.load(checkpoint_path) + + generator.load_state_dict(checkpoint["generator"]) + kp_detector.load_state_dict(checkpoint["kp_detector"]) + he_estimator.load_state_dict(checkpoint["he_estimator"]) + + generator = DataParallelWithCallback(generator) + kp_detector = DataParallelWithCallback(kp_detector) + he_estimator = DataParallelWithCallback(he_estimator) + + generator.eval() + kp_detector.eval() + he_estimator.eval() + print("Model successfully loaded!") + + return generator, kp_detector, he_estimator + + +def headpose_pred_to_degree(pred): + device = pred.device + idx_tensor = [idx for idx in range(66)] + idx_tensor = torch.FloatTensor(idx_tensor).to(device) + pred = F.softmax(pred, dim=1) + degree = torch.sum(pred * idx_tensor, axis=1) * 3 - 99 + + return degree + + +def get_rotation_matrix(yaw, pitch, roll): + yaw = yaw / 180 * 3.14 + pitch = pitch / 180 * 3.14 + roll = roll / 180 * 3.14 + + roll = roll.unsqueeze(1) + pitch = pitch.unsqueeze(1) + yaw = yaw.unsqueeze(1) + + pitch_mat = torch.cat( + [ + torch.ones_like(pitch), + torch.zeros_like(pitch), + torch.zeros_like(pitch), + 
torch.zeros_like(pitch), + torch.cos(pitch), + -torch.sin(pitch), + torch.zeros_like(pitch), + torch.sin(pitch), + torch.cos(pitch), + ], + dim=1, + ) + pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3) + + yaw_mat = torch.cat( + [ + torch.cos(yaw), + torch.zeros_like(yaw), + torch.sin(yaw), + torch.zeros_like(yaw), + torch.ones_like(yaw), + torch.zeros_like(yaw), + -torch.sin(yaw), + torch.zeros_like(yaw), + torch.cos(yaw), + ], + dim=1, + ) + yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3) + + roll_mat = torch.cat( + [ + torch.cos(roll), + -torch.sin(roll), + torch.zeros_like(roll), + torch.sin(roll), + torch.cos(roll), + torch.zeros_like(roll), + torch.zeros_like(roll), + torch.zeros_like(roll), + torch.ones_like(roll), + ], + dim=1, + ) + roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3) + + rot_mat = torch.einsum("bij,bjk,bkm->bim", pitch_mat, yaw_mat, roll_mat) + + return rot_mat + + +def keypoint_transformation(kp_canonical, he, estimate_jacobian=False, free_view=False, yaw=0, pitch=0, roll=0, output_coord=False): + kp = kp_canonical["value"] + if not free_view: + yaw, pitch, roll = he["yaw"], he["pitch"], he["roll"] + yaw = headpose_pred_to_degree(yaw) + pitch = headpose_pred_to_degree(pitch) + roll = headpose_pred_to_degree(roll) + else: + if yaw is not None: + yaw = torch.tensor([yaw]).cuda() + else: + yaw = he["yaw"] + yaw = headpose_pred_to_degree(yaw) + if pitch is not None: + pitch = torch.tensor([pitch]).cuda() + else: + pitch = he["pitch"] + pitch = headpose_pred_to_degree(pitch) + if roll is not None: + roll = torch.tensor([roll]).cuda() + else: + roll = he["roll"] + roll = headpose_pred_to_degree(roll) + + t, exp = he["t"], he["exp"] + + rot_mat = get_rotation_matrix(yaw, pitch, roll) + + # keypoint rotation + kp_rotated = torch.einsum("bmp,bkp->bkm", rot_mat, kp) + + # keypoint translation + t = t.unsqueeze_(1).repeat(1, kp.shape[1], 1) + kp_t = kp_rotated + t + + # add expression deviation + exp = exp.view(exp.shape[0], -1, 3) + kp_transformed 
= kp_t + exp + + if estimate_jacobian: + jacobian = kp_canonical["jacobian"] + jacobian_transformed = torch.einsum("bmp,bkps->bkms", rot_mat, jacobian) + else: + jacobian_transformed = None + + if output_coord: + return {"value": kp_transformed, "jacobian": jacobian_transformed}, { + "yaw": float(yaw.cpu().numpy()), + "pitch": float(pitch.cpu().numpy()), + "roll": float(roll.cpu().numpy()), + } + + return {"value": kp_transformed, "jacobian": jacobian_transformed} + + +def get_square_face(coords, image): + x1, y1, x2, y2 = coords + # expand the face region by 1.5 times + length = max(x2 - x1, y2 - y1) // 2 + x1 = x1 - length * 0.5 + x2 = x2 + length * 0.5 + y1 = y1 - length * 0.5 + y2 = y2 + length * 0.5 + + # get square image + center = (x1 + x2) // 2, (y1 + y2) // 2 + length = max(x2 - x1, y2 - y1) // 2 + x1 = max(int(round(center[0] - length)), 0) + x2 = min(int(round(center[0] + length)), image.shape[1]) + y1 = max(int(round(center[1] - length)), 0) + y2 = min(int(round(center[1] + length)), image.shape[0]) + return image[y1:y2, x1:x2] + + +def smooth_coord(last_coord, current_coord, smooth_factor=0.2): + change = np.array(current_coord) - np.array(last_coord) + # smooth the change to 0.1 times + change = change * smooth_factor + return (np.array(last_coord) + np.array(change)).astype(int).tolist() + + +class FaceAnimationClass: + def __init__(self, source_image_path=None, use_sr=False): + assert source_image_path is not None, "source_image_path is None, please set source_image_path" + config_path = os.path.join(os.path.dirname(__file__), "face_vid2vid/config/vox-256-spade.yaml") + # save to local cache to speed loading + checkpoint_path = os.path.join(os.path.expanduser("~"), ".cache/torch/hub/checkpoints/FaceMapping.pth.tar") + if not os.path.exists(checkpoint_path): + os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True) + from gdown import download + file_id = "11ZgyjKI5OcB7klcsIdPpCCX38AIX8Soc" + download(id=file_id, output=checkpoint_path, 
quiet=False) + if use_sr: + from face_vid2vid.GPEN.face_enhancement import FaceEnhancement + + self.faceenhancer = FaceEnhancement( + size=256, model="GPEN-BFR-256", use_sr=False, sr_model="realesrnet_x2", channel_multiplier=1, narrow=0.5, use_facegan=True + ) + + # load checkpoints + self.generator, self.kp_detector, self.he_estimator = load_checkpoints(config_path=config_path, checkpoint_path=checkpoint_path) + source_image = cv2.cvtColor(cv2.imread(source_image_path), cv2.COLOR_RGB2BGR).astype(np.float32) / 255. + source_image = cv2.resize(source_image, (256, 256), interpolation=cv2.INTER_AREA) + source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2) + self.source = source.cuda() + + # initilize face detectors + self.face_detector = RetinaFace() + self.detect_interval = 8 + self.smooth_factor = 0.2 + + # load base frame and blank frame + self.base_frame = cv2.imread(source_image_path) if not use_sr else self.faceenhancer.process(cv2.imread(source_image_path))[0] + self.base_frame = cv2.resize(self.base_frame, (256, 256)) + self.blank_frame = np.ones(self.base_frame.shape, dtype=np.uint8) * 255 + cv2.putText(self.blank_frame, "Face not", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2) + cv2.putText(self.blank_frame, "detected!", (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2) + + # count for frame + self.n_frame = 0 + + # initilize variables + self.first_frame = True + self.last_coords = None + self.coords = None + self.use_sr = use_sr + self.kp_source = None + self.kp_driving_initial = None + + + def _conver_input_frame(self, frame): + frame = cv2.resize(frame, (256, 256), interpolation=cv2.INTER_NEAREST).astype(np.float32) / 255.0 + return torch.tensor(frame[np.newaxis]).permute(0, 3, 1, 2).cuda() + + def _process_first_frame(self, frame): + print("Processing first frame") + # function to process the first frame + faces = self.face_detector(frame, cv=True) + if len(faces) == 0: + raise ValueError("Face is not 
detected") + else: + self.coords = faces[0][0] + face = get_square_face(self.coords, frame) + self.last_coords = self.coords + + # get the keypoint and headpose from the source image + with torch.no_grad(): + self.kp_canonical = self.kp_detector(self.source) + self.he_source = self.he_estimator(self.source) + + face_input = self._conver_input_frame(face) + he_driving_initial = self.he_estimator(face_input) + self.kp_driving_initial, coordinates = keypoint_transformation(self.kp_canonical, he_driving_initial, output_coord=True) + self.kp_source = keypoint_transformation( + self.kp_canonical, self.he_source, free_view=True, yaw=coordinates["yaw"], pitch=coordinates["pitch"], roll=coordinates["roll"] + ) + + def _inference(self, frame): + # function to process the rest frames + with torch.no_grad(): + self.n_frame += 1 + if self.first_frame: + self._process_first_frame(frame) + self.first_frame = False + else: + pass + if self.n_frame % self.detect_interval == 0: + faces = self.face_detector(frame, cv=True) + if len(faces) == 0: + raise ValueError("Face is not detected") + else: + self.coords = faces[0][0] + self.coords = smooth_coord(self.last_coords, self.coords, self.smooth_factor) + face = get_square_face(self.coords, frame) + self.last_coords = self.coords + face_input = self._conver_input_frame(face) + + he_driving = self.he_estimator(face_input) + kp_driving = keypoint_transformation(self.kp_canonical, he_driving) + kp_norm = normalize_kp( + kp_source=self.kp_source, + kp_driving=kp_driving, + kp_driving_initial=self.kp_driving_initial, + use_relative_movement=True, + adapt_movement_scale=True, + ) + + out = self.generator(self.source, kp_source=self.kp_source, kp_driving=kp_norm, fp16=True) + image = np.transpose(out["prediction"].data.cpu().numpy(), [0, 2, 3, 1])[0] + image = (np.array(image).astype(np.float32) * 255).astype(np.uint8) + result = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + return face, result + + def inference(self, frame): + # function to 
inference, input frame, output cropped face and its result + try: + if frame is not None: + face, result = self._inference(frame) + if self.use_sr: + result, _, _ = self.faceenhancer.process(result) + result = cv2.resize(result, (256, 256)) + return face, result + except Exception as e: + print(e) + self.first_frame = True + self.n_frame = 0 + return self.blank_frame, self.base_frame + + +if __name__ == "__main__": + from tqdm import tqdm + import time + faceanimation = FaceAnimationClass(source_image_path="tmp.png", use_sr=False) + + video_path = "driver.mp4" + capture = cv2.VideoCapture(video_path) + fps = capture.get(cv2.CAP_PROP_FPS) + frames = [] + _, frame = capture.read() + while frame is not None: + frames.append(frame) + _, frame = capture.read() + capture.release() + + output_frames = [] + time_start = time.time() + for frame in tqdm(frames): + face, result = faceanimation.inference(frame) + # show = cv2.hconcat([cv2.resize(face, (result.shape[1], result.shape[0])), result]) + output_frames.append(result) + time_end = time.time() + print("Time cost: %.2f" % (time_end - time_start), "FPS: %.2f" % (len(frames) / (time_end - time_start))) + writer = imageio.get_writer("result2.mp4", fps=fps, quality=9, macro_block_size=1, codec="libx264", pixelformat="yuv420p") + for frame in output_frames: + writer.append_data(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + # writer.append_data(frame) + writer.close() + print("Video saved to result2.mp4") diff --git a/face_vid2vid/frames_dataset.py b/face_vid2vid/frames_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b7101b5bbb03afbdae3582ec624cdef3209552be --- /dev/null +++ b/face_vid2vid/frames_dataset.py @@ -0,0 +1,154 @@ +import os +from skimage import io, img_as_float32 +from skimage.color import gray2rgb +from sklearn.model_selection import train_test_split +from imageio import mimread + +import numpy as np +from torch.utils.data import Dataset +import pandas as pd +from augmentation import 
AllAugmentationTransform +import glob + + +def read_video(name, frame_shape): + """ + Read video which can be: + - an image of concatenated frames + - '.mp4' and'.gif' + - folder with videos + """ + + if os.path.isdir(name): + frames = sorted(os.listdir(name)) + num_frames = len(frames) + video_array = np.array( + [img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)]) + elif name.lower().endswith('.png') or name.lower().endswith('.jpg'): + image = io.imread(name) + + if len(image.shape) == 2 or image.shape[2] == 1: + image = gray2rgb(image) + + if image.shape[2] == 4: + image = image[..., :3] + + image = img_as_float32(image) + + video_array = np.moveaxis(image, 1, 0) + + video_array = video_array.reshape((-1,) + frame_shape) + video_array = np.moveaxis(video_array, 1, 2) + elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'): + video = np.array(mimread(name)) + if len(video.shape) == 3: + video = np.array([gray2rgb(frame) for frame in video]) + if video.shape[-1] == 4: + video = video[..., :3] + video_array = img_as_float32(video) + else: + raise Exception("Unknown file extensions %s" % name) + + return video_array + + +class FramesDataset(Dataset): + """ + Dataset of videos, each video can be represented as: + - an image of concatenated frames + - '.mp4' or '.gif' + - folder with all frames + """ + + def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True, + random_seed=0, pairs_list=None, augmentation_params=None): + self.root_dir = root_dir + self.videos = os.listdir(root_dir) + self.frame_shape = tuple(frame_shape) + self.pairs_list = pairs_list + self.id_sampling = id_sampling + if os.path.exists(os.path.join(root_dir, 'train')): + assert os.path.exists(os.path.join(root_dir, 'test')) + print("Use predefined train-test split.") + if id_sampling: + train_videos = {os.path.basename(video).split('#')[0] for video in + 
os.listdir(os.path.join(root_dir, 'train'))} + train_videos = list(train_videos) + else: + train_videos = os.listdir(os.path.join(root_dir, 'train')) + test_videos = os.listdir(os.path.join(root_dir, 'test')) + self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test') + else: + print("Use random train-test split.") + train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2) + + if is_train: + self.videos = train_videos + else: + self.videos = test_videos + + self.is_train = is_train + + if self.is_train: + self.transform = AllAugmentationTransform(**augmentation_params) + else: + self.transform = None + + def __len__(self): + return len(self.videos) + + def __getitem__(self, idx): + if self.is_train and self.id_sampling: + name = self.videos[idx] + path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4'))) + else: + name = self.videos[idx] + path = os.path.join(self.root_dir, name) + + video_name = os.path.basename(path) + + if self.is_train and os.path.isdir(path): + frames = os.listdir(path) + num_frames = len(frames) + frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) + video_array = [img_as_float32(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx] + else: + video_array = read_video(path, frame_shape=self.frame_shape) + num_frames = len(video_array) + frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range( + num_frames) + video_array = video_array[frame_idx] + + if self.transform is not None: + video_array = self.transform(video_array) + + out = {} + if self.is_train: + source = np.array(video_array[0], dtype='float32') + driving = np.array(video_array[1], dtype='float32') + + out['driving'] = driving.transpose((2, 0, 1)) + out['source'] = source.transpose((2, 0, 1)) + else: + video = np.array(video_array, dtype='float32') + out['video'] = video.transpose((3, 0, 1, 2)) + + out['name'] = video_name + + 
return out + + +class DatasetRepeater(Dataset): + """ + Pass several times over the same dataset for better i/o performance + """ + + def __init__(self, dataset, num_repeats=100): + self.dataset = dataset + self.num_repeats = num_repeats + + def __len__(self): + return self.num_repeats * self.dataset.__len__() + + def __getitem__(self, idx): + return self.dataset[idx % self.dataset.__len__()] diff --git a/face_vid2vid/logger.py b/face_vid2vid/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..c82005cb07700e4a2717a9758a36da8a6644f6a8 --- /dev/null +++ b/face_vid2vid/logger.py @@ -0,0 +1,202 @@ +import numpy as np +import torch +import torch.nn.functional as F +import imageio + +import os +from skimage.draw import circle_perimeter + +import matplotlib.pyplot as plt +import collections + + +class Logger: + def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=None, zfill_num=8, log_file_name="log.txt"): + self.loss_list = [] + self.cpk_dir = log_dir + self.visualizations_dir = os.path.join(log_dir, "train-vis") + if not os.path.exists(self.visualizations_dir): + os.makedirs(self.visualizations_dir) + self.log_file = open(os.path.join(log_dir, log_file_name), "a") + self.zfill_num = zfill_num + self.visualizer = Visualizer(**visualizer_params) + self.checkpoint_freq = checkpoint_freq + self.epoch = 0 + self.best_loss = float("inf") + self.names = None + + def log_scores(self, loss_names): + loss_mean = np.array(self.loss_list).mean(axis=0) + + loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)]) + loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string + + print(loss_string, file=self.log_file) + self.loss_list = [] + self.log_file.flush() + + def visualize_rec(self, inp, out): + image = self.visualizer.visualize(inp["driving"], inp["source"], out) + imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image) + 
+ def save_cpk(self, emergent=False): + cpk = {k: v.state_dict() for k, v in self.models.items()} + cpk["epoch"] = self.epoch + cpk_path = os.path.join(self.cpk_dir, "%s-checkpoint.pth.tar" % str(self.epoch).zfill(self.zfill_num)) + if not (os.path.exists(cpk_path) and emergent): + torch.save(cpk, cpk_path) + + @staticmethod + def load_cpk( + checkpoint_path, + generator=None, + discriminator=None, + kp_detector=None, + he_estimator=None, + optimizer_generator=None, + optimizer_discriminator=None, + optimizer_kp_detector=None, + optimizer_he_estimator=None, + ): + checkpoint = torch.load(checkpoint_path) + if generator is not None: + generator.load_state_dict(checkpoint["generator"]) + if kp_detector is not None: + kp_detector.load_state_dict(checkpoint["kp_detector"]) + if he_estimator is not None: + he_estimator.load_state_dict(checkpoint["he_estimator"]) + if discriminator is not None: + try: + discriminator.load_state_dict(checkpoint["discriminator"]) + except: + print("No discriminator in the state-dict. Dicriminator will be randomly initialized") + if optimizer_generator is not None: + optimizer_generator.load_state_dict(checkpoint["optimizer_generator"]) + if optimizer_discriminator is not None: + try: + optimizer_discriminator.load_state_dict(checkpoint["optimizer_discriminator"]) + except RuntimeError as e: + print("No discriminator optimizer in the state-dict. 
Optimizer will be not initialized") + if optimizer_kp_detector is not None: + optimizer_kp_detector.load_state_dict(checkpoint["optimizer_kp_detector"]) + if optimizer_he_estimator is not None: + optimizer_he_estimator.load_state_dict(checkpoint["optimizer_he_estimator"]) + + return checkpoint["epoch"] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if "models" in self.__dict__: + self.save_cpk() + self.log_file.close() + + def log_iter(self, losses): + losses = collections.OrderedDict(losses.items()) + if self.names is None: + self.names = list(losses.keys()) + self.loss_list.append(list(losses.values())) + + def log_epoch(self, epoch, models, inp, out): + self.epoch = epoch + self.models = models + if (self.epoch + 1) % self.checkpoint_freq == 0: + self.save_cpk() + self.log_scores(self.names) + self.visualize_rec(inp, out) + + +class Visualizer: + def __init__(self, kp_size=5, draw_border=False, colormap="gist_rainbow"): + self.kp_size = kp_size + self.draw_border = draw_border + self.colormap = plt.get_cmap(colormap) + + def draw_image_with_kp(self, image, kp_array): + image = np.copy(image) + spatial_size = np.array(image.shape[:2][::-1])[np.newaxis] + kp_array = spatial_size * (kp_array + 1) / 2 + num_kp = kp_array.shape[0] + for kp_ind, kp in enumerate(kp_array): + rr, cc = circle_perimeter(kp[1], kp[0], self.kp_size, shape=image.shape[:2]) + image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3] + return image + + def create_image_column_with_kp(self, images, kp): + image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)]) + return self.create_image_column(image_array) + + def create_image_column(self, images): + if self.draw_border: + images = np.copy(images) + images[:, :, [0, -1]] = (1, 1, 1) + images[:, :, [0, -1]] = (1, 1, 1) + return np.concatenate(list(images), axis=0) + + def create_image_grid(self, *args): + out = [] + for arg in args: + if type(arg) == tuple: + 
out.append(self.create_image_column_with_kp(arg[0], arg[1])) + else: + out.append(self.create_image_column(arg)) + return np.concatenate(out, axis=1) + + def visualize(self, driving, source, out): + images = [] + + # Source image with keypoints + source = source.data.cpu() + kp_source = out["kp_source"]["value"][:, :, :2].data.cpu().numpy() # 3d -> 2d + source = np.transpose(source, [0, 2, 3, 1]) + images.append((source, kp_source)) + + # Equivariance visualization + if "transformed_frame" in out: + transformed = out["transformed_frame"].data.cpu().numpy() + transformed = np.transpose(transformed, [0, 2, 3, 1]) + transformed_kp = out["transformed_kp"]["value"][:, :, :2].data.cpu().numpy() # 3d -> 2d + images.append((transformed, transformed_kp)) + + # Driving image with keypoints + kp_driving = out["kp_driving"]["value"][:, :, :2].data.cpu().numpy() # 3d -> 2d + driving = driving.data.cpu().numpy() + driving = np.transpose(driving, [0, 2, 3, 1]) + images.append((driving, kp_driving)) + + # Result + prediction = out["prediction"].data.cpu().numpy() + prediction = np.transpose(prediction, [0, 2, 3, 1]) + images.append(prediction) + + ## Occlusion map + if "occlusion_map" in out: + occlusion_map = out["occlusion_map"].data.cpu().repeat(1, 3, 1, 1) + occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy() + occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1]) + images.append(occlusion_map) + + ## Mask + if "mask" in out: + for i in range(out["mask"].shape[1]): + mask = out["mask"][:, i : (i + 1)].data.cpu().sum(2).repeat(1, 3, 1, 1) # (n, 3, h, w) + # mask = F.softmax(mask.view(mask.shape[0], mask.shape[1], -1), dim=2).view(mask.shape) + mask = F.interpolate(mask, size=source.shape[1:3]).numpy() + mask = np.transpose(mask, [0, 2, 3, 1]) + + if i != 0: + color = np.array(self.colormap((i - 1) / (out["mask"].shape[1] - 1)))[:3] + else: + color = np.array((0, 0, 0)) + + color = color.reshape((1, 1, 1, 3)) + + if i != 0: + images.append(mask * 
color) + else: + images.append(mask) + + image = self.create_image_grid(*images) + image = (255 * image).astype(np.uint8) + return image diff --git a/face_vid2vid/modules/__pycache__/dense_motion.cpython-310.pyc b/face_vid2vid/modules/__pycache__/dense_motion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bc426418561f79896bed67a0548e904d929ab49 Binary files /dev/null and b/face_vid2vid/modules/__pycache__/dense_motion.cpython-310.pyc differ diff --git a/face_vid2vid/modules/__pycache__/generator.cpython-310.pyc b/face_vid2vid/modules/__pycache__/generator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74feabd648a646e8b5b42596e4fd71aea4017901 Binary files /dev/null and b/face_vid2vid/modules/__pycache__/generator.cpython-310.pyc differ diff --git a/face_vid2vid/modules/__pycache__/keypoint_detector.cpython-310.pyc b/face_vid2vid/modules/__pycache__/keypoint_detector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4744b792ab6ecda529611cde26422e5d1e89b71a Binary files /dev/null and b/face_vid2vid/modules/__pycache__/keypoint_detector.cpython-310.pyc differ diff --git a/face_vid2vid/modules/__pycache__/util.cpython-310.pyc b/face_vid2vid/modules/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..262307fe40bfe15d9d18829e05d58f94ec403ad1 Binary files /dev/null and b/face_vid2vid/modules/__pycache__/util.cpython-310.pyc differ diff --git a/face_vid2vid/modules/dense_motion.py b/face_vid2vid/modules/dense_motion.py new file mode 100644 index 0000000000000000000000000000000000000000..01db27ae9b69de6c04d50d10dfd5d8a79574128e --- /dev/null +++ b/face_vid2vid/modules/dense_motion.py @@ -0,0 +1,128 @@ +from torch import nn +import torch.nn.functional as F +import torch +from face_vid2vid.modules.util import Hourglass, make_coordinate_grid, kp2gaussian + +from face_vid2vid.sync_batchnorm.batchnorm import 
class DenseMotionNetwork(nn.Module):
    """
    Module that predicts a dense motion field from the sparse motion representation
    given by kp_source and kp_driving.

    ``forward`` returns a dict with:
      * ``'mask'``          - softmax weights over the ``num_kp + 1`` candidate motions,
      * ``'deformation'``   - the mask-weighted sum of the candidate sampling grids,
      * ``'occlusion_map'`` - only when built with ``estimate_occlusion_map=True``.
    """

    def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress,
                 estimate_occlusion_map=False):
        super(DenseMotionNetwork, self).__init__()
        # Hourglass input: for each of the num_kp keypoints plus the background slot,
        # one heatmap channel concatenated with `compress` deformed-feature channels.
        self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (compress + 1),
                                   max_features=max_features, num_blocks=num_blocks)

        self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3)

        self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1)
        self.norm = BatchNorm3d(compress, affine=True)

        if estimate_occlusion_map:
            # The hourglass output is flattened over depth before this 2D conv (see forward).
            self.occlusion = nn.Conv2d(self.hourglass.out_filters * reshape_depth, 1, kernel_size=7, padding=3)
        else:
            self.occlusion = None

        self.num_kp = num_kp

    def create_sparse_motions(self, feature, kp_driving, kp_source):
        """
        Build one candidate sampling grid per keypoint plus an identity (background) grid.

        Returns a tensor of shape (bs, num_kp + 1, d, h, w, 3).
        """
        bs, _, d, h, w = feature.shape
        identity_grid = make_coordinate_grid((d, h, w), type=kp_source['value'].type())
        identity_grid = identity_grid.view(1, 1, d, h, w, 3)
        coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 1, 3)

        # The jacobian is optional: the key may be missing or explicitly set to None.
        if 'jacobian' in kp_driving and kp_driving['jacobian'] is not None:
            jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
            jacobian = jacobian.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3)
            jacobian = jacobian.repeat(1, 1, d, h, w, 1, 1)
            coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
            coordinate_grid = coordinate_grid.squeeze(-1)

        driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 1, 3)  # (bs, num_kp, d, h, w, 3)

        # Prepend the identity grid as the background motion.
        identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
        sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)

        return sparse_motions

    def create_deformed_feature(self, feature, sparse_motions):
        """
        Warp ``feature`` with every candidate grid in ``sparse_motions``.

        Returns a tensor of shape (bs, num_kp + 1, c, d, h, w).
        """
        bs, _, d, h, w = feature.shape
        feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1, 1)  # (bs, num_kp+1, 1, c, d, h, w)
        feature_repeat = feature_repeat.view(bs * (self.num_kp + 1), -1, d, h, w)  # (bs*(num_kp+1), c, d, h, w)
        sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), d, h, w, -1))  # (bs*(num_kp+1), d, h, w, 3)
        # align_corners=False is grid_sample's current default; spelling it out keeps the
        # result unchanged and avoids the per-call "default grid_sample" warning.
        sparse_deformed = F.grid_sample(feature_repeat, sparse_motions, align_corners=False)
        sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, d, h, w))  # (bs, num_kp+1, c, d, h, w)
        return sparse_deformed

    def create_heatmap_representations(self, feature, kp_driving, kp_source):
        """
        Return the difference of driving and source keypoint gaussians, with a zero
        background channel prepended; shape (bs, num_kp + 1, 1, d, h, w).
        """
        # `feature` is the deformed feature (bs, num_kp+1, c, d, h, w), so [3:] is (d, h, w).
        spatial_size = feature.shape[3:]
        gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01)
        gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01)
        heatmap = gaussian_driving - gaussian_source

        # Background slot: an all-zero heatmap.
        zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type())
        heatmap = torch.cat([zeros, heatmap], dim=1)
        heatmap = heatmap.unsqueeze(2)  # (bs, num_kp+1, 1, d, h, w)
        return heatmap

    def forward(self, feature, kp_driving, kp_source):
        """
        Compute the dense deformation (and optional occlusion map) for a 3D feature volume.

        ``feature`` is unpacked as (bs, channels, d, h, w); ``kp_driving`` / ``kp_source``
        are dicts with a ``'value'`` entry and an optional ``'jacobian'`` entry.
        """
        bs, _, d, h, w = feature.shape

        feature = self.compress(feature)
        feature = self.norm(feature)
        feature = F.relu(feature)

        out_dict = dict()
        sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source)
        deformed_feature = self.create_deformed_feature(feature, sparse_motion)

        heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source)

        # Stack heatmap and deformed features per candidate, then fold candidates into channels.
        hourglass_input = torch.cat([heatmap, deformed_feature], dim=2)
        hourglass_input = hourglass_input.view(bs, -1, d, h, w)

        prediction = self.hourglass(hourglass_input)

        mask = self.mask(prediction)
        mask = F.softmax(mask, dim=1)
        out_dict['mask'] = mask
        mask = mask.unsqueeze(2)  # (bs, num_kp+1, 1, d, h, w)
        sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4)  # (bs, num_kp+1, 3, d, h, w)
        deformation = (sparse_motion * mask).sum(dim=1)  # (bs, 3, d, h, w)
        deformation = deformation.permute(0, 2, 3, 4, 1)  # (bs, d, h, w, 3)

        out_dict['deformation'] = deformation

        # Explicit None check instead of relying on the truthiness of an nn.Module.
        if self.occlusion is not None:
            bs, _, d, h, w = prediction.shape
            # Fold channels and depth together so the 2D occlusion conv can consume it.
            prediction = prediction.view(bs, -1, h, w)
            occlusion_map = torch.sigmoid(self.occlusion(prediction))
            out_dict['occlusion_map'] = occlusion_map

        return out_dict
class Discriminator(nn.Module):
    """
    Pix2Pix-style discriminator: a stack of DownBlock2d stages followed by a
    1x1 convolution that yields a single-channel prediction map.
    """

    def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
                 sn=False, **kwargs):
        super().__init__()

        def width(level):
            # Channel count at a given depth, capped at max_features.
            return min(max_features, block_expansion * (2 ** level))

        # The first stage has no normalisation; the last stage has no pooling.
        stages = [
            DownBlock2d(num_channels if level == 0 else width(level),
                        width(level + 1),
                        norm=(level != 0), kernel_size=4, pool=(level != num_blocks - 1), sn=sn)
            for level in range(num_blocks)
        ]
        self.down_blocks = nn.ModuleList(stages)

        head = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
        self.conv = nn.utils.spectral_norm(head) if sn else head

    def forward(self, x):
        """Return (list of per-stage feature maps, prediction map)."""
        feature_maps = []
        current = x
        for stage in self.down_blocks:
            current = stage(current)
            feature_maps.append(current)
        return feature_maps, self.conv(current)
class OcclusionAwareGenerator(nn.Module):
    """
    Generator follows NVIDIA architecture.

    Encodes the source image to a 3D feature volume, warps it with the deformation
    predicted by the dense motion network, and decodes the result back to an image.
    """

    def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
                 num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
        super(OcclusionAwareGenerator, self).__init__()

        if dense_motion_params is not None:
            self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
                                                           estimate_occlusion_map=estimate_occlusion_map,
                                                           **dense_motion_params)
        else:
            self.dense_motion_network = None

        self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(7, 7), padding=(3, 3))

        down_blocks = []
        for i in range(num_down_blocks):
            in_features = min(max_features, block_expansion * (2 ** i))
            out_features = min(max_features, block_expansion * (2 ** (i + 1)))
            down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
        self.down_blocks = nn.ModuleList(down_blocks)

        # `out_features` is the width of the last down block built above.
        self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)

        self.reshape_channel = reshape_channel
        self.reshape_depth = reshape_depth

        self.resblocks_3d = torch.nn.Sequential()
        for i in range(num_resblocks):
            self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))

        out_features = block_expansion * (2 ** (num_down_blocks))
        self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
        self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)

        self.resblocks_2d = torch.nn.Sequential()
        for i in range(num_resblocks):
            self.resblocks_2d.add_module('2dr' + str(i), ResBlock2d(out_features, kernel_size=3, padding=1))

        up_blocks = []
        for i in range(num_down_blocks):
            in_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i)))
            out_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i - 1)))
            up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
        self.up_blocks = nn.ModuleList(up_blocks)

        self.final = nn.Conv2d(block_expansion, image_channel, kernel_size=(7, 7), padding=(3, 3))
        self.estimate_occlusion_map = estimate_occlusion_map
        self.image_channel = image_channel

    def deform_input(self, inp, deformation):
        """
        Warp the 5D volume ``inp`` with the sampling grid ``deformation``.

        The grid is first resized to ``inp``'s (d, h, w) if their sizes differ.
        """
        _, d_old, h_old, w_old, _ = deformation.shape
        _, _, d, h, w = inp.shape
        if d_old != d or h_old != h or w_old != w:
            deformation = deformation.permute(0, 4, 1, 2, 3)
            # align_corners=False matches the implicit default; explicit to avoid the warning.
            deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear', align_corners=False)
            deformation = deformation.permute(0, 2, 3, 4, 1)
        return F.grid_sample(inp, deformation, align_corners=False)

    def forward(self, source_image, kp_driving, kp_source):
        """
        Generate the driven image.

        Returns a dict with ``'prediction'`` (sigmoid-activated image) and, when a dense
        motion network is configured, ``'mask'`` and optionally ``'occlusion_map'``.
        """
        # Encoding (downsampling) part
        out = self.first(source_image)
        for down_block in self.down_blocks:
            out = down_block(out)
        out = self.second(out)
        bs, _, h, w = out.shape
        # Split the channel axis into (reshape_channel, reshape_depth) to get a 3D volume.
        feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h, w)
        feature_3d = self.resblocks_3d(feature_3d)

        # Transforming feature representation according to deformation and occlusion
        output_dict = {}
        if self.dense_motion_network is not None:
            dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
                                                     kp_source=kp_source)
            output_dict['mask'] = dense_motion['mask']

            if 'occlusion_map' in dense_motion:
                occlusion_map = dense_motion['occlusion_map']
                output_dict['occlusion_map'] = occlusion_map
            else:
                occlusion_map = None
            deformation = dense_motion['deformation']
            out = self.deform_input(feature_3d, deformation)

            # Fold depth back into channels for the 2D decoder.
            bs, c, d, h, w = out.shape
            out = out.view(bs, c * d, h, w)
            out = self.third(out)
            out = self.fourth(out)

            if occlusion_map is not None:
                if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
                    occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear',
                                                  align_corners=False)
                out = out * occlusion_map

        # Decoding part
        out = self.resblocks_2d(out)
        for up_block in self.up_blocks:
            out = up_block(out)
        out = self.final(out)
        # torch.sigmoid replaces the deprecated F.sigmoid alias (same result).
        out = torch.sigmoid(out)

        output_dict["prediction"] = out

        return output_dict
class OcclusionAwareSPADEGenerator(nn.Module):
    """
    Occlusion-aware generator whose decoding stage is a SPADEDecoder.

    The encoder and warping path mirror OcclusionAwareGenerator; ``forward`` can
    optionally cast its inputs to half precision.
    """

    def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
                 num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
        super(OcclusionAwareSPADEGenerator, self).__init__()

        if dense_motion_params is not None:
            self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
                                                           estimate_occlusion_map=estimate_occlusion_map,
                                                           **dense_motion_params)
        else:
            self.dense_motion_network = None

        self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))

        down_blocks = []
        for i in range(num_down_blocks):
            in_features = min(max_features, block_expansion * (2 ** i))
            out_features = min(max_features, block_expansion * (2 ** (i + 1)))
            down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
        self.down_blocks = nn.ModuleList(down_blocks)

        # `out_features` is the width of the last down block built above.
        self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)

        self.reshape_channel = reshape_channel
        self.reshape_depth = reshape_depth

        self.resblocks_3d = torch.nn.Sequential()
        for i in range(num_resblocks):
            self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))

        out_features = block_expansion * (2 ** (num_down_blocks))
        self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
        self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)

        self.estimate_occlusion_map = estimate_occlusion_map
        self.image_channel = image_channel

        self.decoder = SPADEDecoder()

    def deform_input(self, inp, deformation):
        """
        Warp the 5D volume ``inp`` with the sampling grid ``deformation``.

        The grid is first resized to ``inp``'s (d, h, w) if their sizes differ.
        """
        _, d_old, h_old, w_old, _ = deformation.shape
        _, _, d, h, w = inp.shape
        if d_old != d or h_old != h or w_old != w:
            deformation = deformation.permute(0, 4, 1, 2, 3)
            # align_corners=False matches the implicit default; explicit to avoid the warning.
            deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear', align_corners=False)
            deformation = deformation.permute(0, 2, 3, 4, 1)
        return F.grid_sample(inp, deformation, align_corners=False)

    def forward(self, source_image, kp_driving, kp_source, fp16=False):
        """
        Generate the driven image with the SPADE decoder.

        When ``fp16`` is true the source image and the keypoint ``'value'`` tensors are
        cast to half precision. Returns a dict with ``'prediction'`` and, when a dense
        motion network is configured, ``'mask'`` and optionally ``'occlusion_map'``.
        """
        if fp16:
            source_image = source_image.half()
            # Cast on shallow copies: the previous in-place assignment silently turned
            # the caller's keypoint tensors into fp16 for everything that ran afterwards.
            kp_driving = dict(kp_driving, value=kp_driving['value'].half())
            kp_source = dict(kp_source, value=kp_source['value'].half())

        # Encoding (downsampling) part
        out = self.first(source_image)
        for down_block in self.down_blocks:
            out = down_block(out)
        out = self.second(out)
        bs, _, h, w = out.shape
        # Split the channel axis into (reshape_channel, reshape_depth) to get a 3D volume.
        feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h, w)
        feature_3d = self.resblocks_3d(feature_3d)

        # Transforming feature representation according to deformation and occlusion
        output_dict = {}
        if self.dense_motion_network is not None:
            dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
                                                     kp_source=kp_source)
            output_dict['mask'] = dense_motion['mask']

            if 'occlusion_map' in dense_motion:
                occlusion_map = dense_motion['occlusion_map']
                output_dict['occlusion_map'] = occlusion_map
            else:
                occlusion_map = None
            deformation = dense_motion['deformation']
            out = self.deform_input(feature_3d, deformation)

            # Fold depth back into channels for the 2D decoder.
            bs, c, d, h, w = out.shape
            out = out.view(bs, c * d, h, w)
            out = self.third(out)
            out = self.fourth(out)

            if occlusion_map is not None:
                if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
                    occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear',
                                                  align_corners=False)
                out = out * occlusion_map

        # Decoding part
        out = self.decoder(out)

        output_dict["prediction"] = out

        return output_dict
class ResNet(nn.Module):
    """
    ResNet backbone with a single linear head (``fc_angles``) of ``num_classes`` outputs.

    ``block`` is the residual block class (it must expose an ``expansion`` attribute)
    and ``layers`` gives the number of blocks in each of the four stages.
    """

    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)
        self.fc_angles = nn.Linear(512 * block.expansion, num_classes)

        # Re-initialise every conv with N(0, sqrt(2 / (kh * kw * out_channels)))
        # and every batch norm with weight 1 / bias 0.
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``."""
        stage_width = planes * block.expansion

        # A 1x1 conv + BN projection is needed whenever the first block changes
        # the resolution or the channel count.
        downsample = None
        if stride != 1 or self.inplanes != stage_width:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, stage_width,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(stage_width),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = stage_width
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the backbone and return the ``fc_angles`` output."""
        pipeline = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1, self.layer2, self.layer3, self.layer4,
            self.avgpool,
        )
        for step in pipeline:
            x = step(x)

        flat = x.view(x.size(0), -1)
        return self.fc_angles(flat)
+ def __init__(self, num_bins): + super(AlexNet, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(64, 192, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(192, 384, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + ) + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(256 * 6 * 6, 4096), + nn.ReLU(inplace=True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(inplace=True), + ) + self.fc_yaw = nn.Linear(4096, num_bins) + self.fc_pitch = nn.Linear(4096, num_bins) + self.fc_roll = nn.Linear(4096, num_bins) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), 256 * 6 * 6) + x = self.classifier(x) + yaw = self.fc_yaw(x) + pitch = self.fc_pitch(x) + roll = self.fc_roll(x) + return yaw, pitch, roll diff --git a/face_vid2vid/modules/keypoint_detector.py b/face_vid2vid/modules/keypoint_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..78c17570502f83376ebe632557727e3233c2dae1 --- /dev/null +++ b/face_vid2vid/modules/keypoint_detector.py @@ -0,0 +1,178 @@ +from torch import nn +import torch +import torch.nn.functional as F + +from face_vid2vid.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d as BatchNorm2d +from face_vid2vid.modules.util import KPHourglass, make_coordinate_grid, AntiAliasInterpolation2d, ResBottleneck + + +class KPDetector(nn.Module): + """ + Detecting canonical keypoints. Return keypoint position and jacobian near each keypoint. 
+ """ + + def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, reshape_channel, reshape_depth, + num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False): + super(KPDetector, self).__init__() + + self.predictor = KPHourglass(block_expansion, in_features=image_channel, + max_features=max_features, reshape_features=reshape_channel, reshape_depth=reshape_depth, num_blocks=num_blocks) + + # self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=7, padding=3) + self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=3, padding=1) + + if estimate_jacobian: + self.num_jacobian_maps = 1 if single_jacobian_map else num_kp + # self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=7, padding=3) + self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=3, padding=1) + ''' + initial as: + [[1 0 0] + [0 1 0] + [0 0 1]] + ''' + self.jacobian.weight.data.zero_() + self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) + else: + self.jacobian = None + + self.temperature = temperature + self.scale_factor = scale_factor + if self.scale_factor != 1: + self.down = AntiAliasInterpolation2d(image_channel, self.scale_factor) + + def gaussian2kp(self, heatmap): + """ + Extract the mean from a heatmap + """ + shape = heatmap.shape + heatmap = heatmap.unsqueeze(-1) + grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) + value = (heatmap * grid).sum(dim=(2, 3, 4)) + kp = {'value': value} + + return kp + + def forward(self, x): + if self.scale_factor != 1: + x = self.down(x) + + feature_map = self.predictor(x) + prediction = self.kp(feature_map) + + final_shape = prediction.shape + heatmap = prediction.view(final_shape[0], 
class HEEstimator(nn.Module):
    """
    Estimating head pose and expression.

    ``forward`` returns a dict with binned ``'yaw'``, ``'pitch'`` and ``'roll'``
    outputs (``num_bins`` values each), a 3-vector ``'t'`` and an ``'exp'`` vector
    of ``3 * num_kp`` values.
    """

    def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, num_bins=66, estimate_jacobian=True):
        super().__init__()

        def bottleneck_stack(prefix, channels, count):
            # `count` stride-1 bottlenecks registered as '<prefix>0', '<prefix>1', ...
            stack = nn.Sequential()
            for idx in range(count):
                stack.add_module(prefix + str(idx), ResBottleneck(in_features=channels, stride=1))
            return stack

        self.conv1 = nn.Conv2d(in_channels=image_channel, out_channels=block_expansion, kernel_size=7, padding=3, stride=2)
        self.norm1 = BatchNorm2d(block_expansion, affine=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.conv2 = nn.Conv2d(in_channels=block_expansion, out_channels=256, kernel_size=1)
        self.norm2 = BatchNorm2d(256, affine=True)

        self.block1 = bottleneck_stack('b1_', 256, 3)

        self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1)
        self.norm3 = BatchNorm2d(512, affine=True)
        self.block2 = ResBottleneck(in_features=512, stride=2)

        self.block3 = bottleneck_stack('b3_', 512, 3)

        self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1)
        self.norm4 = BatchNorm2d(1024, affine=True)
        self.block4 = ResBottleneck(in_features=1024, stride=2)

        self.block5 = bottleneck_stack('b5_', 1024, 5)

        self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1)
        self.norm5 = BatchNorm2d(2048, affine=True)
        self.block6 = ResBottleneck(in_features=2048, stride=2)

        self.block7 = bottleneck_stack('b7_', 2048, 2)

        self.fc_roll = nn.Linear(2048, num_bins)
        self.fc_pitch = nn.Linear(2048, num_bins)
        self.fc_yaw = nn.Linear(2048, num_bins)

        self.fc_t = nn.Linear(2048, 3)

        self.fc_exp = nn.Linear(2048, 3 * num_kp)

    def forward(self, x):
        """Run the backbone on ``x`` and return the pose / translation / expression heads."""
        out = self.maxpool(F.relu(self.norm1(self.conv1(x))))
        out = F.relu(self.norm2(self.conv2(out)))
        out = self.block1(out)

        # Each widening stage: 1x1 conv -> norm -> relu -> stride-2 bottleneck -> stride-1 stack.
        widening_stages = (
            (self.conv3, self.norm3, self.block2, self.block3),
            (self.conv4, self.norm4, self.block4, self.block5),
            (self.conv5, self.norm5, self.block6, self.block7),
        )
        for conv, norm, strided, stack in widening_stages:
            out = F.relu(norm(conv(out)))
            out = strided(out)
            out = stack(out)

        pooled = F.adaptive_avg_pool2d(out, 1)
        pooled = pooled.view(pooled.shape[0], -1)

        # NOTE(review): 'yaw' is produced by fc_roll and 'roll' by fc_yaw. This pairing is
        # reproduced exactly as found; existing checkpoints presumably rely on it, so
        # confirm against the trained weights before "fixing" the names.
        return {
            'yaw': self.fc_roll(pooled),
            'pitch': self.fc_pitch(pooled),
            'roll': self.fc_yaw(pooled),
            't': self.fc_t(pooled),
            'exp': self.fc_exp(pooled),
        }
class ImagePyramide(torch.nn.Module):
    """
    Build an image pyramid (one anti-aliased resize per scale) for the
    pyramid perceptual loss.
    """
    def __init__(self, scales, num_channels):
        super().__init__()
        # ModuleDict keys may not contain '.', so a scale such as 0.5 is stored as '0-5'.
        self.downs = nn.ModuleDict({
            str(scale).replace('.', '-'): AntiAliasInterpolation2d(num_channels, scale)
            for scale in scales
        })

    def forward(self, x):
        """Return ``{'prediction_<scale>': resized x}`` for every configured scale."""
        return {
            'prediction_' + str(key).replace('-', '.'): resize(x)
            for key, resize in self.downs.items()
        }
** 2 + result = result * torch.log(distances + 1e-6) + result = result * control_params + result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1) + transformed = transformed + result + + return transformed + + def jacobian(self, coordinates): + new_coordinates = self.warp_coordinates(coordinates) + grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True) + grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True) + jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2) + return jacobian + + +def detach_kp(kp): + return {key: value.detach() for key, value in kp.items()} + + +def headpose_pred_to_degree(pred): + device = pred.device + idx_tensor = [idx for idx in range(66)] + idx_tensor = torch.FloatTensor(idx_tensor).to(device) + pred = F.softmax(pred) + degree = torch.sum(pred*idx_tensor, axis=1) * 3 - 99 + + return degree + +''' +# beta version +def get_rotation_matrix(yaw, pitch, roll): + yaw = yaw / 180 * 3.14 + pitch = pitch / 180 * 3.14 + roll = roll / 180 * 3.14 + + roll = roll.unsqueeze(1) + pitch = pitch.unsqueeze(1) + yaw = yaw.unsqueeze(1) + + roll_mat = torch.cat([torch.ones_like(roll), torch.zeros_like(roll), torch.zeros_like(roll), + torch.zeros_like(roll), torch.cos(roll), -torch.sin(roll), + torch.zeros_like(roll), torch.sin(roll), torch.cos(roll)], dim=1) + roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3) + + pitch_mat = torch.cat([torch.cos(pitch), torch.zeros_like(pitch), torch.sin(pitch), + torch.zeros_like(pitch), torch.ones_like(pitch), torch.zeros_like(pitch), + -torch.sin(pitch), torch.zeros_like(pitch), torch.cos(pitch)], dim=1) + pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3) + + yaw_mat = torch.cat([torch.cos(yaw), -torch.sin(yaw), torch.zeros_like(yaw), + torch.sin(yaw), torch.cos(yaw), torch.zeros_like(yaw), + torch.zeros_like(yaw), torch.zeros_like(yaw), torch.ones_like(yaw)], dim=1) + yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3) + + rot_mat = 
torch.einsum('bij,bjk,bkm->bim', roll_mat, pitch_mat, yaw_mat) + + return rot_mat +''' + +def get_rotation_matrix(yaw, pitch, roll): + yaw = yaw / 180 * 3.14 + pitch = pitch / 180 * 3.14 + roll = roll / 180 * 3.14 + + roll = roll.unsqueeze(1) + pitch = pitch.unsqueeze(1) + yaw = yaw.unsqueeze(1) + + pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch), + torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch), + torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1) + pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3) + + yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw), + torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw), + -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1) + yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3) + + roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll), + torch.sin(roll), torch.cos(roll), torch.zeros_like(roll), + torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1) + roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3) + + rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat) + + return rot_mat + +def keypoint_transformation(kp_canonical, he, estimate_jacobian=True): + kp = kp_canonical['value'] # (bs, k, 3) + yaw, pitch, roll = he['yaw'], he['pitch'], he['roll'] + t, exp = he['t'], he['exp'] + + yaw = headpose_pred_to_degree(yaw) + pitch = headpose_pred_to_degree(pitch) + roll = headpose_pred_to_degree(roll) + + rot_mat = get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3) + + # keypoint rotation + kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp) + + # keypoint translation + t = t.unsqueeze_(1).repeat(1, kp.shape[1], 1) + kp_t = kp_rotated + t + + # add expression deviation + exp = exp.view(exp.shape[0], -1, 3) + kp_transformed = kp_t + exp + + if estimate_jacobian: + jacobian = kp_canonical['jacobian'] # (bs, k ,3, 3) + jacobian_transformed = 
torch.einsum('bmp,bkps->bkms', rot_mat, jacobian) + else: + jacobian_transformed = None + + return {'value': kp_transformed, 'jacobian': jacobian_transformed} + +class GeneratorFullModel(torch.nn.Module): + """ + Merge all generator related updates into single model for better multi-gpu usage + """ + + def __init__(self, kp_extractor, he_estimator, generator, discriminator, train_params, estimate_jacobian=True): + super(GeneratorFullModel, self).__init__() + self.kp_extractor = kp_extractor + self.he_estimator = he_estimator + self.generator = generator + self.discriminator = discriminator + self.train_params = train_params + self.scales = train_params['scales'] + self.disc_scales = self.discriminator.scales + self.pyramid = ImagePyramide(self.scales, generator.image_channel) + if torch.cuda.is_available(): + self.pyramid = self.pyramid.cuda() + + self.loss_weights = train_params['loss_weights'] + + self.estimate_jacobian = estimate_jacobian + + if sum(self.loss_weights['perceptual']) != 0: + self.vgg = Vgg19() + if torch.cuda.is_available(): + self.vgg = self.vgg.cuda() + + if self.loss_weights['headpose'] != 0: + self.hopenet = hopenet.Hopenet(models.resnet.Bottleneck, [3, 4, 6, 3], 66) + print('Loading hopenet') + hopenet_state_dict = torch.load(train_params['hopenet_snapshot']) + self.hopenet.load_state_dict(hopenet_state_dict) + if torch.cuda.is_available(): + self.hopenet = self.hopenet.cuda() + self.hopenet.eval() + + + def forward(self, x): + kp_canonical = self.kp_extractor(x['source']) # {'value': value, 'jacobian': jacobian} + + he_source = self.he_estimator(x['source']) # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} + he_driving = self.he_estimator(x['driving']) # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} + + # {'value': value, 'jacobian': jacobian} + kp_source = keypoint_transformation(kp_canonical, he_source, self.estimate_jacobian) + kp_driving = keypoint_transformation(kp_canonical, he_driving, 
self.estimate_jacobian) + + generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving) + generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) + + loss_values = {} + + pyramide_real = self.pyramid(x['driving']) + pyramide_generated = self.pyramid(generated['prediction']) + + if sum(self.loss_weights['perceptual']) != 0: + value_total = 0 + for scale in self.scales: + x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) + y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) + + for i, weight in enumerate(self.loss_weights['perceptual']): + value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() + value_total += self.loss_weights['perceptual'][i] * value + loss_values['perceptual'] = value_total + + if self.loss_weights['generator_gan'] != 0: + discriminator_maps_generated = self.discriminator(pyramide_generated) + discriminator_maps_real = self.discriminator(pyramide_real) + value_total = 0 + for scale in self.disc_scales: + key = 'prediction_map_%s' % scale + if self.train_params['gan_mode'] == 'hinge': + value = -torch.mean(discriminator_maps_generated[key]) + elif self.train_params['gan_mode'] == 'ls': + value = ((1 - discriminator_maps_generated[key]) ** 2).mean() + else: + raise ValueError('Unexpected gan_mode {}'.format(self.train_params['gan_mode'])) + + value_total += self.loss_weights['generator_gan'] * value + loss_values['gen_gan'] = value_total + + if sum(self.loss_weights['feature_matching']) != 0: + value_total = 0 + for scale in self.disc_scales: + key = 'feature_maps_%s' % scale + for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])): + if self.loss_weights['feature_matching'][i] == 0: + continue + value = torch.abs(a - b).mean() + value_total += self.loss_weights['feature_matching'][i] * value + loss_values['feature_matching'] = value_total + + if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0: + transform 
= Transform(x['driving'].shape[0], **self.train_params['transform_params']) + transformed_frame = transform.transform_frame(x['driving']) + + transformed_he_driving = self.he_estimator(transformed_frame) + + transformed_kp = keypoint_transformation(kp_canonical, transformed_he_driving, self.estimate_jacobian) + + generated['transformed_frame'] = transformed_frame + generated['transformed_kp'] = transformed_kp + + ## Value loss part + if self.loss_weights['equivariance_value'] != 0: + # project 3d -> 2d + kp_driving_2d = kp_driving['value'][:, :, :2] + transformed_kp_2d = transformed_kp['value'][:, :, :2] + value = torch.abs(kp_driving_2d - transform.warp_coordinates(transformed_kp_2d)).mean() + loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value + + ## jacobian loss part + if self.loss_weights['equivariance_jacobian'] != 0: + # project 3d -> 2d + transformed_kp_2d = transformed_kp['value'][:, :, :2] + transformed_jacobian_2d = transformed_kp['jacobian'][:, :, :2, :2] + jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp_2d), + transformed_jacobian_2d) + + jacobian_2d = kp_driving['jacobian'][:, :, :2, :2] + normed_driving = torch.inverse(jacobian_2d) + normed_transformed = jacobian_transformed + value = torch.matmul(normed_driving, normed_transformed) + + eye = torch.eye(2).view(1, 1, 2, 2).type(value.type()) + + value = torch.abs(eye - value).mean() + loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value + + if self.loss_weights['keypoint'] != 0: + # print(kp_driving['value'].shape) # (bs, k, 3) + value_total = 0 + for i in range(kp_driving['value'].shape[1]): + for j in range(kp_driving['value'].shape[1]): + dist = F.pairwise_distance(kp_driving['value'][:, i, :], kp_driving['value'][:, j, :], p=2, keepdim=True) ** 2 + dist = 0.1 - dist # set Dt = 0.1 + dd = torch.gt(dist, 0) + value = (dist * dd).mean() + value_total += value + + kp_mean_depth = kp_driving['value'][:, :, 
class DiscriminatorFullModel(torch.nn.Module):
    """
    Merge all discriminator related updates into single model for better multi-gpu usage
    """

    def __init__(self, kp_extractor, generator, discriminator, train_params):
        super(DiscriminatorFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.image_channel)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        self.loss_weights = train_params['loss_weights']

        # Lazily created one-element zero tensor, see get_zero_tensor.
        self.zero_tensor = None

    def get_zero_tensor(self, input):
        """Return a zero tensor broadcast to the shape of `input`.

        The cached one-element tensor is created on the device and with the
        dtype of `input` (and re-created if those change), so the hinge loss
        also works on CPU-only hosts instead of unconditionally calling .cuda().
        """
        zero = self.zero_tensor
        if zero is None or zero.device != input.device or zero.dtype != input.dtype:
            zero = torch.zeros(1, device=input.device, dtype=input.dtype)
            self.zero_tensor = zero
        return zero.expand_as(input)

    def forward(self, x, generated):
        """Compute the discriminator GAN loss.

        Args:
            x: dict holding the real frames under 'driving'.
            generated: dict holding the generator output under 'prediction';
                it is detached so no gradient reaches the generator.

        Returns:
            dict with the weighted loss summed over all scales under 'disc_gan'.

        Raises:
            ValueError: if train_params['gan_mode'] is neither 'hinge' nor 'ls'.
        """
        pyramide_real = self.pyramid(x['driving'])
        pyramide_generated = self.pyramid(generated['prediction'].detach())

        discriminator_maps_generated = self.discriminator(pyramide_generated)
        discriminator_maps_real = self.discriminator(pyramide_real)

        loss_values = {}
        value_total = 0
        for scale in self.scales:
            key = 'prediction_map_%s' % scale
            real_map = discriminator_maps_real[key]
            fake_map = discriminator_maps_generated[key]
            if self.train_params['gan_mode'] == 'hinge':
                # Hinge loss: penalise real scores below +1 and fake scores above -1.
                real_term = torch.mean(torch.min(real_map - 1, self.get_zero_tensor(real_map)))
                fake_term = torch.mean(torch.min(-fake_map - 1, self.get_zero_tensor(fake_map)))
                value = -real_term - fake_term
            elif self.train_params['gan_mode'] == 'ls':
                value = ((1 - real_map) ** 2 + fake_map ** 2).mean()
            else:
                raise ValueError('Unexpected gan_mode {}'.format(self.train_params['gan_mode']))

            value_total += self.loss_weights['discriminator_gan'] * value
        loss_values['disc_gan'] = value_total

        return loss_values
def _unit_axis(length, type):
    """Return `length` evenly spaced coordinates from -1 to 1 as a tensor of the given type."""
    steps = torch.arange(length).type(type)
    return (2 * (steps / (length - 1)) - 1)


def make_coordinate_grid_2d(spatial_size, type):
    """
    Create a meshgrid [-1,1] x [-1,1] of given spatial_size.

    spatial_size is (h, w); the result has shape (h, w, 2) and its last
    dimension holds (x, y) for each cell. `type` is a tensor type accepted by
    Tensor.type().
    """
    rows, cols = spatial_size
    x_axis = _unit_axis(cols, type)
    y_axis = _unit_axis(rows, type)

    x_plane = x_axis.view(1, -1).repeat(rows, 1)
    y_plane = y_axis.view(-1, 1).repeat(1, cols)

    return torch.stack([x_plane, y_plane], dim=2)


def make_coordinate_grid(spatial_size, type):
    """
    Create a meshgrid [-1,1]^3 of given spatial_size.

    spatial_size is (d, h, w); the result has shape (d, h, w, 3) and its last
    dimension holds (x, y, z) for each cell.
    """
    depth, rows, cols = spatial_size
    x_axis = _unit_axis(cols, type)
    y_axis = _unit_axis(rows, type)
    z_axis = _unit_axis(depth, type)

    y_vol = y_axis.view(1, -1, 1).repeat(depth, 1, cols)
    x_vol = x_axis.view(1, 1, -1).repeat(depth, rows, 1)
    z_vol = z_axis.view(-1, 1, 1).repeat(1, rows, cols)

    return torch.stack([x_vol, y_vol, z_vol], dim=3)
class ResBlock2d(nn.Module):
    """
    Res block, preserve spatial resolution.

    Two (norm -> ReLU -> conv) stages followed by a residual connection; the
    channel count stays at in_features throughout.
    """

    def __init__(self, in_features, kernel_size, padding):
        super().__init__()
        conv_args = dict(in_channels=in_features, out_channels=in_features,
                         kernel_size=kernel_size, padding=padding)
        # Creation order (convs first, then norms) is kept as in the original.
        self.conv1 = nn.Conv2d(**conv_args)
        self.conv2 = nn.Conv2d(**conv_args)
        self.norm1 = BatchNorm2d(in_features, affine=True)
        self.norm2 = BatchNorm2d(in_features, affine=True)

    def forward(self, x):
        out = x
        for norm, conv in ((self.norm1, self.conv1), (self.norm2, self.conv2)):
            out = conv(F.relu(norm(out)))
        # Residual connection.
        out += x
        return out
class UpBlock2d(nn.Module):
    """
    Upsampling block for use in decoder.

    Doubles the two trailing spatial dimensions, then applies conv -> norm -> ReLU.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super().__init__()
        self.conv = nn.Conv2d(in_features, out_features, kernel_size=kernel_size,
                              padding=padding, groups=groups)
        self.norm = BatchNorm2d(out_features, affine=True)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2)
        return F.relu(self.norm(self.conv(upsampled)))


class UpBlock3d(nn.Module):
    """
    Upsampling block for use in decoder.

    Doubles the two trailing spatial dimensions (the first of the three is left
    unchanged), then applies conv -> norm -> ReLU.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super().__init__()
        self.conv = nn.Conv3d(in_features, out_features, kernel_size=kernel_size,
                              padding=padding, groups=groups)
        self.norm = BatchNorm3d(out_features, affine=True)

    def forward(self, x):
        # Default interpolation mode is used (a trilinear variant was tried and left disabled).
        upsampled = F.interpolate(x, scale_factor=(1, 2, 2))
        return F.relu(self.norm(self.conv(upsampled)))
class DownBlock3d(nn.Module):
    """
    Downsampling block for use in encoder.

    conv -> norm -> ReLU, then average pooling with a (1, 2, 2) window.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super().__init__()
        # Downsampling is done by the pooling layer, not by a strided convolution.
        self.conv = nn.Conv3d(in_features, out_features, kernel_size=kernel_size,
                              padding=padding, groups=groups)
        self.norm = BatchNorm3d(out_features, affine=True)
        self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2))

    def forward(self, x):
        activated = F.relu(self.norm(self.conv(x)))
        return self.pool(activated)
class Encoder(nn.Module):
    """
    Hourglass Encoder

    A chain of DownBlock3d modules whose channel width doubles at every level,
    capped at max_features.
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super().__init__()

        def width(level):
            return min(max_features, block_expansion * (2 ** level))

        stages = [
            DownBlock3d(in_features if level == 0 else width(level), width(level + 1),
                        kernel_size=3, padding=1)
            for level in range(num_blocks)
        ]
        self.down_blocks = nn.ModuleList(stages)

    def forward(self, x):
        """Return [x, level-1 output, level-2 output, ...] for use as skip connections."""
        features = [x]
        current = x
        for stage in self.down_blocks:
            current = stage(current)
            features.append(current)
        return features
class Hourglass(nn.Module):
    """
    Hourglass architecture.

    Feeds the list of encoder feature maps straight into the decoder.
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super().__init__()
        self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
        self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
        self.out_filters = self.decoder.out_filters

    def forward(self, x):
        encoded = self.encoder(x)
        return self.decoder(encoded)


class KPHourglass(nn.Module):
    """
    Hourglass architecture.

    2-D down blocks, a 1x1 conv to reshape_features channels, a reshape that
    splits the channels into (channels // reshape_depth, reshape_depth), then
    3-D up blocks.
    """

    def __init__(self, block_expansion, in_features, reshape_features, reshape_depth, num_blocks=3, max_features=256):
        super().__init__()

        def width(level):
            return min(max_features, block_expansion * (2 ** level))

        self.down_blocks = nn.Sequential()
        for i in range(num_blocks):
            block = DownBlock2d(in_features if i == 0 else width(i), width(i + 1),
                                kernel_size=3, padding=1)
            self.down_blocks.add_module('down' + str(i), block)

        in_filters = width(num_blocks)
        self.conv = nn.Conv2d(in_channels=in_filters, out_channels=reshape_features, kernel_size=1)

        self.up_blocks = nn.Sequential()
        for i in range(num_blocks):
            in_filters = width(num_blocks - i)
            out_filters = width(num_blocks - i - 1)
            self.up_blocks.add_module('up' + str(i), UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))

        self.reshape_depth = reshape_depth
        # Width of the last up block.
        self.out_filters = out_filters

    def forward(self, x):
        feat = self.conv(self.down_blocks(x))
        bs, c, h, w = feat.shape
        # Split the channel axis into (channels, depth) to obtain a 5-D volume.
        volume = feat.view(bs, c // self.reshape_depth, self.reshape_depth, h, w)
        return self.up_blocks(volume)


class AntiAliasInterpolation2d(nn.Module):
    """
    Band-limited downsampling, for better preservation of the input signal.

    Blurs each channel with a fixed Gaussian kernel whose width is derived
    from `scale`, then keeps every int(1 / scale)-th pixel.
    """

    def __init__(self, channels, scale):
        super().__init__()
        std = (1 / scale - 1) / 2
        side = 2 * round(std * 4) + 1
        self.ka = side // 2
        self.kb = self.ka - 1 if side % 2 == 0 else self.ka

        # The 2-D Gaussian is the product of one 1-D Gaussian per dimension.
        axes = [torch.arange(side, dtype=torch.float32) for _ in range(2)]
        center = (side - 1) / 2
        kernel = 1
        for grid in torch.meshgrid(axes):
            kernel = kernel * torch.exp(-(grid - center) ** 2 / (2 * std ** 2))

        # Normalise so the kernel sums to 1.
        kernel = kernel / torch.sum(kernel)
        # Depthwise convolution weight: one copy of the kernel per channel.
        kernel = kernel.view(1, 1, side, side).repeat(channels, 1, 1, 1)

        self.register_buffer('weight', kernel)
        self.groups = channels
        self.scale = scale
        self.int_inv_scale = int(1 / scale)

    def forward(self, input):
        # Nothing to do at full resolution.
        if self.scale == 1.0:
            return input

        padded = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
        blurred = F.conv2d(padded, weight=self.weight, groups=self.groups)
        step = self.int_inv_scale
        return blurred[:, :, ::step, ::step]
class SPADEResnetBlock(nn.Module):
    """
    Residual block whose normalisation layers are SPADE modules conditioned on `seg1`.

    The main path is (SPADE -> leaky ReLU -> conv) twice. The shortcut is the
    identity when fin == fout, otherwise a SPADE-normalised 1x1 convolution.
    """

    def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
        super().__init__()
        # Attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)
        self.use_se = use_se

        # Convolutions (padding equals dilation for the 3x3 kernels).
        conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
        conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
        conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) if self.learned_shortcut else None

        # Optional spectral normalisation, selected by the norm_G string.
        if 'spectral' in norm_G:
            conv_0 = spectral_norm(conv_0)
            conv_1 = spectral_norm(conv_1)
            if conv_s is not None:
                conv_s = spectral_norm(conv_s)

        self.conv_0 = conv_0
        self.conv_1 = conv_1
        if conv_s is not None:
            self.conv_s = conv_s

        # Normalisation layers.
        self.norm_0 = SPADE(fin, label_nc)
        self.norm_1 = SPADE(fmiddle, label_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(fin, label_nc)

    def forward(self, x, seg1):
        skip = self.shortcut(x, seg1)
        dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
        return skip + dx

    def shortcut(self, x, seg1):
        """Return the skip-path tensor (x itself when no learned shortcut is used)."""
        if not self.learned_shortcut:
            return x
        return self.conv_s(self.norm_s(x, seg1))

    def actvn(self, x):
        """Leaky ReLU with negative slope 0.2."""
        return F.leaky_relu(x, 2e-1)
from modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
from modules.discriminator import MultiScaleDiscriminator
from modules.keypoint_detector import KPDetector, HEEstimator

import torch

from train import train

if __name__ == "__main__":
    # Command-line entry point: build the networks described by a YAML config
    # and hand them to train().
    if sys.version_info[0] < 3:
        raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")

    parser = ArgumentParser()
    parser.add_argument("--config", default="config/vox-256.yaml", help="path to config")
    parser.add_argument("--mode", default="train", choices=["train"])
    parser.add_argument("--gen", default="original", choices=["original", "spade"])
    parser.add_argument("--log_dir", default="log", help="path to log into")
    parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
    parser.add_argument(
        "--device_ids",
        default="0, 1, 2, 3, 4, 5, 6, 7",
        type=lambda x: [int(token) for token in x.split(",")],
        help="Names of the devices comma separated.",
    )
    parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")
    parser.set_defaults(verbose=False)

    opt = parser.parse_args()
    with open(opt.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    config_name = os.path.basename(opt.config)
    if opt.checkpoint is None:
        # Fresh run: <log_dir>/<config stem> <timestamp>
        log_dir = os.path.join(opt.log_dir, config_name.split(".")[0])
        log_dir += " " + strftime("%d_%m_%y_%H.%M.%S", gmtime())
    else:
        # Resuming: log next to the checkpoint.
        log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])

    model_params = config["model_params"]
    common_params = model_params["common_params"]
    use_cuda = torch.cuda.is_available()

    generator_classes = {"original": OcclusionAwareGenerator, "spade": OcclusionAwareSPADEGenerator}
    generator = generator_classes[opt.gen](**model_params["generator_params"], **common_params)
    if use_cuda:
        print("cuda is available")
        generator.to(opt.device_ids[0])
    if opt.verbose:
        print(generator)

    discriminator = MultiScaleDiscriminator(**model_params["discriminator_params"], **common_params)
    if use_cuda:
        discriminator.to(opt.device_ids[0])
    if opt.verbose:
        print(discriminator)

    kp_detector = KPDetector(**model_params["kp_detector_params"], **common_params)
    if use_cuda:
        kp_detector.to(opt.device_ids[0])
    if opt.verbose:
        print(kp_detector)

    he_estimator = HEEstimator(**model_params["he_estimator_params"], **common_params)
    if use_cuda:
        he_estimator.to(opt.device_ids[0])

    dataset = FramesDataset(is_train=(opt.mode == "train"), **config["dataset_params"])

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    # Keep a copy of the config next to the logs.
    if not os.path.exists(os.path.join(log_dir, config_name)):
        copy(opt.config, log_dir)

    if opt.mode == "train":
        print("Training...")
        train(config, generator, discriminator, kp_detector, he_estimator, opt.checkpoint, log_dir, dataset, opt.device_ids)
# ---- file: face_vid2vid/sync_batchnorm/batchnorm.py ----
# -*- coding: utf-8 -*-
# File   : batchnorm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

from .comm import SyncMaster

__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']


def _sum_ft(tensor):
    """Sum over the first and last dimensions."""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """Add new dimensions at the front and the tail."""
    return tensor.unsqueeze(0).unsqueeze(-1)


# Per-replica partial statistics sent to the master copy.
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
# Statistics sent back to every replica. The first field is named `sum` but is
# filled with the mean computed in `_data_parallel_master`.
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    """Batch norm whose training statistics are reduced across data-parallel replicas.

    Outside of a data-parallel replica, or in evaluation mode, `forward` defers
    to `F.batch_norm`.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        # Filled in by __data_parallel_replicate__ once the module is replicated.
        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        """Normalize *input*, synchronizing statistics when replicated and training."""
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        """Replication callback: record the replica id and wire up master/slave pipes."""
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        # `broadcasted` is flat: two tensors (mean, inv_std) per target device.
        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and inverse standard-deviation from sum and square-sum.

        This method also maintains the moving average on the master device.
        Returns `(mean, inv_std)`; the biased variance is clamped from below at
        `eps` before the inverse square root is taken.
        """
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps) ** -0.5


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)


# ---- file: face_vid2vid/sync_batchnorm/comm.py ----
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store *result* and wake the waiting consumer; the slot must be empty."""
        with self._lock:
            assert self._result is None, 'Previous result has\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        """Block until a result is available, then return it and clear the slot."""
        with self._lock:
            # Re-check the predicate in a loop: Condition.wait() may return
            # without a matching notify (spurious wakeup), and a single `if`
            # would then hand back None.
            while self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        """Send *msg* to the master, wait for its reply, acknowledge, and return the reply."""
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        # Acknowledge receipt so the master can finish its forward pass.
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger a callback on each module, all slave devices should
    call `register(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be
    collected and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message to be
    passed back to each slave device.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        # Only the callback is carried over; queue and registry are rebuilt.
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually is the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            # First registration after a forward pass: start a fresh registry.
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for _ in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for _ in range(self.nr_slaves):
            # Drain one acknowledgement per slave. The get() must stay outside
            # the assert: under `python -O` asserts are stripped, and leftover
            # acknowledgements would be read as messages on the next pass.
            ack = self._queue.get()
            assert ack is True

        return results[0][1]

    @property
    def nr_slaves(self):
        """Number of registered slave devices."""
        return len(self._registry)


# ---- file: face_vid2vid/sync_batchnorm/replicate.py ----
# -*- coding: utf-8 -*-
# File   : replicate.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    """Empty attribute bag shared by all replicas of one sub-module."""
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Note that, as all replicas are isomorphic, we assign each sub-module a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
    of any slave copies.

    An empty `modules` sequence is a no-op.
    """
    if not modules:
        # Nothing was replicated; there is no master copy to size the contexts from.
        return

    master_copy = modules[0]
    nr_modules = sum(1 for _ in master_copy.modules())
    # One context per sub-module position, shared across all replicas.
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for copy_id, module in enumerate(modules):
        for position, sub_module in enumerate(module.modules()):
            if hasattr(sub_module, '__data_parallel_replicate__'):
                sub_module.__data_parallel_replicate__(ctxs[position], copy_id)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
    original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    # Instance-level override; other DataParallel objects are unaffected.
    data_parallel.replicate = new_replicate


# ---- file: face_vid2vid/sync_batchnorm/unittest.py ----
# -*- coding: utf-8 -*-
# File   : unittest.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import unittest

import numpy as np
from torch.autograd import Variable


def as_numpy(v):
    """Return *v* as a NumPy array on the CPU, unwrapping a `Variable` first."""
    if isinstance(v, Variable):
        v = v.data
    return v.cpu().numpy()


class TorchTestCase(unittest.TestCase):
    """`unittest.TestCase` with a tensor closeness assertion."""

    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
        """Assert that tensors *a* and *b* are element-wise close.

        Both tolerances are forwarded to `np.allclose`; previously `rtol` was
        accepted but silently ignored.
        """
        npa, npb = as_numpy(a), as_numpy(b)
        self.assertTrue(
            np.allclose(npa, npb, atol=atol, rtol=rtol),
            'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
        )


# ---- file: face_vid2vid/train.py ----
from tqdm import trange
import torch

from torch.utils.data import DataLoader

from logger import Logger
from modules.model import GeneratorFullModel, DiscriminatorFullModel

from torch.optim.lr_scheduler import MultiStepLR

from sync_batchnorm import DataParallelWithCallback

from frames_dataset import DatasetRepeater


def train(config, generator, discriminator, kp_detector, he_estimator, checkpoint, log_dir, dataset, device_ids):
    """Run the adversarial training loop.

    Args:
        config: parsed configuration; reads `train_params`, `visualizer_params`
            and `model_params.common_params.estimate_jacobian`.
        generator, discriminator, kp_detector, he_estimator: the networks to train.
        checkpoint: path of a checkpoint to resume from, or None to start at epoch 0.
        log_dir: directory handed to `Logger`.
        dataset: training dataset; wrapped in `DatasetRepeater` when
            `train_params["num_repeats"]` is set to something other than 1.
        device_ids: device ids for `DataParallelWithCallback` (used only when CUDA is available).
    """
    train_params = config["train_params"]
    milestones = train_params["epoch_milestones"]
    adam_betas = (0.5, 0.999)

    optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params["lr_generator"], betas=adam_betas)
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params["lr_discriminator"], betas=adam_betas)
    optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params["lr_kp_detector"], betas=adam_betas)
    optimizer_he_estimator = torch.optim.Adam(he_estimator.parameters(), lr=train_params["lr_he_estimator"], betas=adam_betas)

    if checkpoint is not None:
        start_epoch = Logger.load_cpk(
            checkpoint,
            generator,
            discriminator,
            kp_detector,
            he_estimator,
            optimizer_generator,
            optimizer_discriminator,
            optimizer_kp_detector,
            optimizer_he_estimator,
        )
    else:
        start_epoch = 0

    scheduler_generator = MultiStepLR(optimizer_generator, milestones, gamma=0.1, last_epoch=start_epoch - 1)
    scheduler_discriminator = MultiStepLR(optimizer_discriminator, milestones, gamma=0.1, last_epoch=start_epoch - 1)
    # A network with learning rate 0 restarts its schedule from scratch
    # (last_epoch=-1) instead of resuming at start_epoch.
    scheduler_kp_detector = MultiStepLR(
        optimizer_kp_detector, milestones, gamma=0.1, last_epoch=-1 + start_epoch * (train_params["lr_kp_detector"] != 0)
    )
    # NOTE(review): this was gated on lr_kp_detector, which looks like a
    # copy-paste slip; each scheduler is now gated on its own learning rate.
    scheduler_he_estimator = MultiStepLR(
        optimizer_he_estimator, milestones, gamma=0.1, last_epoch=-1 + start_epoch * (train_params["lr_he_estimator"] != 0)
    )

    # The old test (`"num_repeats" in ... or ...["num_repeats"] != 1`) raised
    # KeyError whenever the key was missing; a missing key now means 1.
    num_repeats = train_params.get("num_repeats", 1)
    if num_repeats != 1:
        dataset = DatasetRepeater(dataset, num_repeats)
    dataloader = DataLoader(dataset, batch_size=train_params["batch_size"], shuffle=True, num_workers=16, drop_last=True)

    generator_full = GeneratorFullModel(
        kp_detector,
        he_estimator,
        generator,
        discriminator,
        train_params,
        estimate_jacobian=config["model_params"]["common_params"]["estimate_jacobian"],
    )
    discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)

    if torch.cuda.is_available():
        generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids)
        discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids)

    with Logger(log_dir=log_dir, visualizer_params=config["visualizer_params"], checkpoint_freq=train_params["checkpoint_freq"]) as logger:
        for epoch in trange(start_epoch, train_params["num_epochs"]):
            for x in dataloader:
                # Generator / keypoint detector / head-pose estimator step.
                losses_generator, generated = generator_full(x)

                loss_values = [val.mean() for val in losses_generator.values()]
                loss = sum(loss_values)

                loss.backward()
                optimizer_generator.step()
                optimizer_generator.zero_grad()
                optimizer_kp_detector.step()
                optimizer_kp_detector.zero_grad()
                optimizer_he_estimator.step()
                optimizer_he_estimator.zero_grad()

                # Discriminator step, skipped when the GAN loss is disabled.
                if train_params["loss_weights"]["generator_gan"] != 0:
                    optimizer_discriminator.zero_grad()
                    losses_discriminator = discriminator_full(x, generated)
                    loss_values = [val.mean() for val in losses_discriminator.values()]
                    loss = sum(loss_values)

                    loss.backward()
                    optimizer_discriminator.step()
                    optimizer_discriminator.zero_grad()
                else:
                    losses_discriminator = {}

                losses_generator.update(losses_discriminator)
                losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
                logger.log_iter(losses=losses)

            scheduler_generator.step()
            scheduler_discriminator.step()
            scheduler_kp_detector.step()
            scheduler_he_estimator.step()

            # `x` and `generated` are the last batch of the epoch.
            logger.log_epoch(
                epoch,
                {
                    "generator": generator,
                    "discriminator": discriminator,
                    "kp_detector": kp_detector,
                    "he_estimator": he_estimator,
                    "optimizer_generator": optimizer_generator,
                    "optimizer_discriminator": optimizer_discriminator,
                    "optimizer_kp_detector": optimizer_kp_detector,
                    "optimizer_he_estimator": optimizer_he_estimator,
                },
                inp=x,
                out=generated,
            )


# ---- file: image.py ----
from config import *
from diffusers import AutoPipelineForText2Image
from argparse import ArgumentParser
import humanize
import datetime as dt


def generate_image(path_id, imgfile, prompt):
    """Generate one image for *prompt* and save it as temp/<path_id>/<imgfile>."""
    pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
    image = pipe(prompt=prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
    image.save(os.path.join("temp", path_id, imgfile))


def generate_images(path_id, imgfile, prompt, times=1):
    """Generate *times* images for *prompt*, saved as temp/<path_id>/<i>_<imgfile>."""
    pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
    for i in range(times):
        image = pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0.0).images[0]
        image.save(os.path.join("temp", path_id, str(i) + "_" + imgfile))


if __name__ == '__main__':
    path_id = str(int(time.time()))
    path = os.path.join("temp", "image", path_id)
    os.makedirs(path, exist_ok=True)

    parser = ArgumentParser()
    parser.add_argument("--prompt", default="Young woman with long, blonde hair, smiling slightly",
                        help="avatar prompt")
    parser.add_argument("--times", type=int, default=1, help="number of avatars to generate")
    args = parser.parse_args()

    tstart = time.time()

    # Adjacent literals instead of a backslash continuation inside the string,
    # so source indentation no longer ends up inside the prompt.
    full_prompt = (
        f"hyper-realistic digital avatar, centered, {args.prompt}, "
        "rim lighting, studio lighting, looking at the camera"
    )
    generate_images(os.path.join("image", path_id), "avatar.png", full_prompt, args.times)

    print("total time:", humanize.naturaldelta(dt.timedelta(seconds=int(time.time() - tstart))))


# ---- file: improve.py ----
import time
from config import *
import cv2
import glob
import numpy as np
import os
from basicsr.utils import imwrite
from pathos.pools import ParallelPool
import subprocess
import platform
from mutagen.wave import WAVE
import tqdm
from p_tqdm import *
import torch
from PIL import Image
from RealESRGAN import RealESRGAN


def vid2frames(vidPath, framesOutPath):
    """Dump every frame of *vidPath* into *framesOutPath* as 00001.png, 00002.png, ..."""
    print(vidPath)
    print(framesOutPath)
    vidcap = cv2.VideoCapture(vidPath)
    try:
        frame = 1
        success, image = vidcap.read()
        while success:
            cv2.imwrite(os.path.join(framesOutPath, str(frame).zfill(5) + '.png'), image)
            success, image = vidcap.read()
            frame += 1
    finally:
        # Release the capture handle even if writing a frame fails.
        vidcap.release()


def restore_frames(audiofilePath, videoOutPath, improveOutputPath):
    """Mux the PNG frames in *improveOutputPath* with the audio track into *videoOutPath*.

    The frame rate is the frame count divided by the audio duration, so the
    video spans the length of the audio.
    """
    no_of_frames = count_files(improveOutputPath)
    audio_duration = get_audio_duration(audiofilePath)
    framesPath = improveOutputPath + "/%5d.png"
    fps = no_of_frames/audio_duration
    # Argument list instead of a shell string: paths containing spaces or shell
    # metacharacters are passed to ffmpeg verbatim.
    command = [
        "ffmpeg", "-y",
        "-r", str(fps),
        "-f", "image2",
        "-i", framesPath,
        "-i", str(audiofilePath),
        "-vcodec", "mpeg4",
        "-b:v", "20000k",
        str(videoOutPath),
    ]
    print(" ".join(command))
    subprocess.call(command)


def get_audio_duration(audioPath):
    """Return the length of the WAVE file at *audioPath* as reported by mutagen."""
    audio = WAVE(audioPath)
    duration = audio.info.length
    return duration


def count_files(directory):
    """Return the number of regular files directly inside *directory*."""
    return len([name for name in os.listdir(directory) if os.path.isfile(os.path.join(directory, name))])


def improve(disassembledPath, improvedPath):
    """Upscale every PNG in *disassembledPath* with RealESRGAN (x4) into *improvedPath*."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = RealESRGAN(device, scale=4)
    model.load_weights('weights/RealESRGAN_x4.pth', download=True)

    files = glob.glob(os.path.join(disassembledPath,"*.png"))

    results = t_map(real_esrgan, files, [model]*len(files), [improvedPath] * len(files))


def real_esrgan(img_path, model, improvedPath):
    """Upscale one image with *model* and save it under the same file name in *improvedPath*."""
    image = Image.open(img_path).convert('RGB')
    sr_image = model.predict(image)
    img_name = os.path.basename(img_path)
    sr_image.save(os.path.join(improvedPath, img_name))


# NOTE: a commented-out GFPGAN-based implementation (`process`, `improve_faces`)
# and a commented-out ParallelPool variant of `improve` used to live here; the
# dead code was removed and remains available in version control.

# ---- next file in this diff: lips(1).py ----
b/lips(1).py new file mode 100644 index 0000000000000000000000000000000000000000..acf57d4bc7ce7d8f812f1792eb33498b050608b8 --- /dev/null +++ b/lips(1).py @@ -0,0 +1,232 @@ +from config import * +import os +import numpy as np +import cv2, wav2lip.audio +import subprocess +from tqdm import tqdm +import glob +import torch, wav2lip.face_detection +from wav2lip.models import Wav2Lip +import platform + + +def get_smoothened_boxes(boxes, T): + for i in range(len(boxes)): + if i + T > len(boxes): + window = boxes[len(boxes) - T:] + else: + window = boxes[i : i + T] + boxes[i] = np.mean(window, axis=0) + return boxes + +def face_detect(images): + detector = wav2lip.face_detection.FaceAlignment(wav2lip.face_detection.LandmarksType._2D, flip_input=False, device=device) + batch_size = face_det_batch_size + + while 1: + predictions = [] + try: + for i in tqdm(range(0, len(images), batch_size)): + predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size]))) + except RuntimeError: + if batch_size == 1: + raise RuntimeError('Image too big to run face detection on GPU. Please change resize_factor') + batch_size //= 2 + print('Recovering from OOM error; New batch size: {}'.format(batch_size)) + continue + break + + results = [] + pady1, pady2, padx1, padx2 = pads + for rect, image in zip(predictions, images): + if rect is None: + cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected. + raise ValueError('Face not detected! 
Ensure the video contains a face in all the frames.') + + y1 = max(0, rect[1] - pady1) + y2 = min(image.shape[0], rect[3] + pady2) + x1 = max(0, rect[0] - padx1) + x2 = min(image.shape[1], rect[2] + padx2) + + results.append([x1, y1, x2, y2]) + + boxes = np.array(results) + if not nosmooth: boxes = get_smoothened_boxes(boxes, T=5) + results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)] + + del detector + return results + +def datagen(frames, mels): + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + + if box[0] == -1: + if not static: + face_det_results = face_detect(frames) # BGR2RGB for CNN face detection + else: + face_det_results = face_detect([frames[0]]) + else: + print('Using the specified bounding box instead of face detection...') + y1, y2, x1, x2 = box + face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames] + + for i, m in enumerate(mels): + idx = 0 if static else i%len(frames) + frame_to_save = frames[idx].copy() + face, coords = face_det_results[idx].copy() + + face = cv2.resize(face, (img_size, img_size)) + + img_batch.append(face) + mel_batch.append(m) + frame_batch.append(frame_to_save) + coords_batch.append(coords) + + if len(img_batch) >= wav2lip_batch_size: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, img_size//2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. + mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + + if len(img_batch) > 0: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, img_size//2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. 
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch + + + +def _load(checkpoint_path): + if device == 'cuda': + checkpoint = torch.load(checkpoint_path) + else: + checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) + + return checkpoint + +def load_model(path): + model = Wav2Lip() + print("Load checkpoint from: {}".format(path)) + checkpoint = _load(path) + s = checkpoint["state_dict"] + new_s = {} + for k, v in s.items(): + new_s[k.replace('module.', '')] = v + model.load_state_dict(new_s) + + model = model.to(device) + return model.eval() + +def modify_lips(path_id, audiofile, animatedfile, outfilePath): + animatedfilePath = os.path.join("temp", path_id, animatedfile) + audiofilePath = os.path.join("temp", path_id, audiofile) + tempAudioPath = os.path.join("temp", path_id, "temp.wav") + tempVideoPath = os.path.join("temp", path_id, "temp.avi") + + if not os.path.isfile(animatedfilePath): + raise ValueError('--face argument must be a valid path to video/image file') + + elif animatedfilePath.split('.')[1] in ['jpg', 'png', 'jpeg']: + full_frames = [cv2.imread(animatedfilePath)] + fps = fps + + else: + video_stream = cv2.VideoCapture(animatedfilePath) + fps = video_stream.get(cv2.CAP_PROP_FPS) + + print('Reading video frames...') + + full_frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + if resize_factor > 1: + frame = cv2.resize(frame, (frame.shape[1]//resize_factor, frame.shape[0]//resize_factor)) + + if rotate: + frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE) + + y1, y2, x1, x2 = crop + if x2 == -1: x2 = frame.shape[1] + if y2 == -1: y2 = frame.shape[0] + + frame = frame[y1:y2, x1:x2] + + full_frames.append(frame) + + print ("Number of frames available for inference: "+str(len(full_frames))) + + print('Extracting raw audio...') + command = 
'ffmpeg -y -i {} -strict -2 {}'.format(audiofilePath, tempAudioPath) + subprocess.call(command, shell=True) + + + wav = wav2lip.audio.load_wav(tempAudioPath, 16000) + mel = wav2lip.audio.melspectrogram(wav) + print(mel.shape) + + if np.isnan(mel.reshape(-1)).sum() > 0: + raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again') + + mel_chunks = [] + mel_idx_multiplier = 80./fps + i = 0 + while 1: + start_idx = int(i * mel_idx_multiplier) + if start_idx + mel_step_size > len(mel[0]): + mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:]) + break + mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size]) + i += 1 + + print("Length of mel chunks: {}".format(len(mel_chunks))) + + full_frames = full_frames[:len(mel_chunks)] + + batch_size = wav2lip_batch_size + gen = datagen(full_frames.copy(), mel_chunks) + + for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen, total=int(np.ceil(float(len(mel_chunks))/batch_size)))): + if i == 0: + model = load_model(checkpoint_path) + print ("Model loaded") + + frame_h, frame_w = full_frames[0].shape[:-1] + out = cv2.VideoWriter(tempVideoPath, cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h)) + + img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device) + mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device) + + with torch.no_grad(): + pred = model(mel_batch, img_batch) + + pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255. 
+ + for p, f, c in zip(pred, frames, coords): + y1, y2, x1, x2 = c + p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1)) + + f[y1:y2, x1:x2] = p + out.write(f) + + out.release() + + command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(tempAudioPath, tempVideoPath, outfilePath) + subprocess.call(command, shell=platform.system() != 'Windows') + + + + diff --git a/lips.py b/lips.py new file mode 100644 index 0000000000000000000000000000000000000000..acf57d4bc7ce7d8f812f1792eb33498b050608b8 --- /dev/null +++ b/lips.py @@ -0,0 +1,232 @@ +from config import * +import os +import numpy as np +import cv2, wav2lip.audio +import subprocess +from tqdm import tqdm +import glob +import torch, wav2lip.face_detection +from wav2lip.models import Wav2Lip +import platform + + +def get_smoothened_boxes(boxes, T): + for i in range(len(boxes)): + if i + T > len(boxes): + window = boxes[len(boxes) - T:] + else: + window = boxes[i : i + T] + boxes[i] = np.mean(window, axis=0) + return boxes + +def face_detect(images): + detector = wav2lip.face_detection.FaceAlignment(wav2lip.face_detection.LandmarksType._2D, flip_input=False, device=device) + batch_size = face_det_batch_size + + while 1: + predictions = [] + try: + for i in tqdm(range(0, len(images), batch_size)): + predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size]))) + except RuntimeError: + if batch_size == 1: + raise RuntimeError('Image too big to run face detection on GPU. Please change resize_factor') + batch_size //= 2 + print('Recovering from OOM error; New batch size: {}'.format(batch_size)) + continue + break + + results = [] + pady1, pady2, padx1, padx2 = pads + for rect, image in zip(predictions, images): + if rect is None: + cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected. + raise ValueError('Face not detected! 
Ensure the video contains a face in all the frames.') + + y1 = max(0, rect[1] - pady1) + y2 = min(image.shape[0], rect[3] + pady2) + x1 = max(0, rect[0] - padx1) + x2 = min(image.shape[1], rect[2] + padx2) + + results.append([x1, y1, x2, y2]) + + boxes = np.array(results) + if not nosmooth: boxes = get_smoothened_boxes(boxes, T=5) + results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)] + + del detector + return results + +def datagen(frames, mels): + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + + if box[0] == -1: + if not static: + face_det_results = face_detect(frames) # BGR2RGB for CNN face detection + else: + face_det_results = face_detect([frames[0]]) + else: + print('Using the specified bounding box instead of face detection...') + y1, y2, x1, x2 = box + face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames] + + for i, m in enumerate(mels): + idx = 0 if static else i%len(frames) + frame_to_save = frames[idx].copy() + face, coords = face_det_results[idx].copy() + + face = cv2.resize(face, (img_size, img_size)) + + img_batch.append(face) + mel_batch.append(m) + frame_batch.append(frame_to_save) + coords_batch.append(coords) + + if len(img_batch) >= wav2lip_batch_size: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, img_size//2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. + mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + + if len(img_batch) > 0: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, img_size//2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. 
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch + + + +def _load(checkpoint_path): + if device == 'cuda': + checkpoint = torch.load(checkpoint_path) + else: + checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) + + return checkpoint + +def load_model(path): + model = Wav2Lip() + print("Load checkpoint from: {}".format(path)) + checkpoint = _load(path) + s = checkpoint["state_dict"] + new_s = {} + for k, v in s.items(): + new_s[k.replace('module.', '')] = v + model.load_state_dict(new_s) + + model = model.to(device) + return model.eval() + +def modify_lips(path_id, audiofile, animatedfile, outfilePath): + animatedfilePath = os.path.join("temp", path_id, animatedfile) + audiofilePath = os.path.join("temp", path_id, audiofile) + tempAudioPath = os.path.join("temp", path_id, "temp.wav") + tempVideoPath = os.path.join("temp", path_id, "temp.avi") + + if not os.path.isfile(animatedfilePath): + raise ValueError('--face argument must be a valid path to video/image file') + + elif animatedfilePath.split('.')[1] in ['jpg', 'png', 'jpeg']: + full_frames = [cv2.imread(animatedfilePath)] + fps = fps + + else: + video_stream = cv2.VideoCapture(animatedfilePath) + fps = video_stream.get(cv2.CAP_PROP_FPS) + + print('Reading video frames...') + + full_frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + if resize_factor > 1: + frame = cv2.resize(frame, (frame.shape[1]//resize_factor, frame.shape[0]//resize_factor)) + + if rotate: + frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE) + + y1, y2, x1, x2 = crop + if x2 == -1: x2 = frame.shape[1] + if y2 == -1: y2 = frame.shape[0] + + frame = frame[y1:y2, x1:x2] + + full_frames.append(frame) + + print ("Number of frames available for inference: "+str(len(full_frames))) + + print('Extracting raw audio...') + command = 
'ffmpeg -y -i {} -strict -2 {}'.format(audiofilePath, tempAudioPath) + subprocess.call(command, shell=True) + + + wav = wav2lip.audio.load_wav(tempAudioPath, 16000) + mel = wav2lip.audio.melspectrogram(wav) + print(mel.shape) + + if np.isnan(mel.reshape(-1)).sum() > 0: + raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again') + + mel_chunks = [] + mel_idx_multiplier = 80./fps + i = 0 + while 1: + start_idx = int(i * mel_idx_multiplier) + if start_idx + mel_step_size > len(mel[0]): + mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:]) + break + mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size]) + i += 1 + + print("Length of mel chunks: {}".format(len(mel_chunks))) + + full_frames = full_frames[:len(mel_chunks)] + + batch_size = wav2lip_batch_size + gen = datagen(full_frames.copy(), mel_chunks) + + for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen, total=int(np.ceil(float(len(mel_chunks))/batch_size)))): + if i == 0: + model = load_model(checkpoint_path) + print ("Model loaded") + + frame_h, frame_w = full_frames[0].shape[:-1] + out = cv2.VideoWriter(tempVideoPath, cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h)) + + img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device) + mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device) + + with torch.no_grad(): + pred = model(mel_batch, img_batch) + + pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255. 
+ + for p, f, c in zip(pred, frames, coords): + y1, y2, x1, x2 = c + p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1)) + + f[y1:y2, x1:x2] = p + out.write(f) + + out.release() + + command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(tempAudioPath, tempVideoPath, outfilePath) + subprocess.call(command, shell=platform.system() != 'Windows') + + + + diff --git a/persona.py b/persona.py new file mode 100644 index 0000000000000000000000000000000000000000..8aba3ed0afef520d012116a1705986e7b7b05c03 --- /dev/null +++ b/persona.py @@ -0,0 +1,126 @@ +from config import * +from speech import generate_speech +from image import generate_image +from lips import modify_lips +import humanize +import datetime as dt +from argparse import ArgumentParser +import shutil + +import os +import glob +from improve import improve, vid2frames, restore_frames +from animate_face import animate_face + +message = """Over the holiday season, capturing photos and videos of the festivities with family and friends + is an important activity for many. The iPhone has a suite of camera features that can significantly elevate + the quality and creativity of your holiday photos and videos.""" +#message = """Apple today confirmed that it will be permanently closing its Infinite Loop retail store in +#Cupertino, California on January 20. 
Infinite Loop served as Apple's headquarters between the mid-1990s and +#2017, when its current Apple Park headquarters opened a few miles away.""" + +def main(): + parser = ArgumentParser() + parser.add_argument("--improve", action="store_true", help="use Real ESRGAN to improve the video") + parser.add_argument("--skipgen", action="store_true", help="improve the video only") + parser.add_argument("--path_id", default=str(int(time.time())), help="set the path id to use") + parser.add_argument("--speech", default=audiofile, help="path to WAV speech file") + parser.add_argument("--image", default=imgfile, help="path to avatar file") + args = parser.parse_args() + tstart = time.time() + + ## SET PATH + path_id = args.path_id + path = os.path.join("temp", path_id) + print("path_id:", path_id, "path:", path) + os.makedirs(path, exist_ok=True) + outfile = os.path.join("results", path_id + "_small.mp4") + finalfile = os.path.join("results", path_id + "_large.mp4") + + if not args.skipgen: + ## GENERATE SPEECH + tspeech = "None" + if args.speech == audiofile: + print("-----------------------------------------") + print("generating speech") + t0 = time.time() + generate_speech(path_id, audiofile, "daniel", message, "ultra_fast") + tspeech = humanize.naturaldelta(dt.timedelta(seconds=int(time.time() - t0))) + print("\ngenerating speech:", tspeech) + else: + print("using:", args.speech) + shutil.copyfile(args.speech, os.path.join("temp", path_id, audiofile)) + + ## GENERATE AVATAR IMAGE + timage = "avatar.png" + shutil.copyfile(timage, os.path.join("temp", path_id, imgfile)) + shutil.copyfile(args.image, os.path.join("temp", path_id, imgfile)) + shutil.copyfile(args.image, os.path.join("temp", path_id, timage)) + #if args.image == imgfile: + #print("-----------------------------------------") + #print("generating avatar image") + #t1 = time.time() + #avatar_description = "Middle-aged black man, Idris Elba, with short dark hair, serious look" + #generate_image(path_id, 
imgfile, f"hyperrealistic digital avatar, centered, {avatar_description}, \ + # rim lighting, studio lighting, looking at the camera") + #timage = humanize.naturaldelta(dt.timedelta(seconds=int(time.time() - t1))) + #print("\ngenerating avatar:", timage) + #else: + #shutil.copyfile(args.image, os.path.join("temp", path_id, imgfile)) + + ## ANIMATE AVATAR IMAGE + + print("-----------------------------------------") + print("animating face with driver") + t2 = time.time() + # audiofile determines the length of the driver movie to trim + # driver movie is imposed on the image file to produce the animated file + animate_face(path_id, audiofile, driverfile, imgfile, animatedfile) + tanimate = humanize.naturaldelta(dt.timedelta(seconds=int(time.time() - t2))) + print("\nanimating face:", tanimate) + + ## MODIFY LIPS TO FIT THE SPEECH + + print("-----------------------------------------") + print("modifying lips") + t3 = time.time() + os.makedirs("results", exist_ok=True) + + modify_lips(path_id, audiofile, animatedfile, outfile) + tlips = humanize.naturaldelta(dt.timedelta(seconds=int(time.time() - t3))) + print("\nmodifying lips:", tlips) + + ## IMPROVE THE OUTPUT VIDEO + if args.improve: + t4 = time.time() + print("-----------------------------------------") + print("converting video to frames") + shutil.rmtree(os.path.join(path, "improve"), ignore_errors=True) + os.makedirs(os.path.join(path, "improve", "disassembled"), exist_ok=True) + os.makedirs(os.path.join(path, "improve", "improved"), exist_ok=True) + + vid2frames(outfile, os.path.join(path, "improve", "disassembled")) + print("-----------------------------------------") + print("improving face") + improve(os.path.join(path, "improve", "disassembled"), os.path.join(path, "improve", "improved")) + print("-----------------------------------------") + print("restoring frames") + + restore_frames(os.path.join(path, audiofile), finalfile, os.path.join(path, "improve", "improved")) + timprove = 
humanize.naturaldelta(dt.timedelta(seconds=int(time.time() - t4))) + print("\nimproving video:", timprove) + + print("done") + print("Overall timing") + print("--------------") + if not args.skipgen: + print("generating speech:", tspeech) + print("generating avatar image:", timage) + print("animating face:", tanimate) + print("modifying lips:", tlips) + if args.improve: + print("improving finished video:", timprove) + print("total time:", humanize.naturaldelta(dt.timedelta(seconds=int(time.time() - tstart)))) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..08c83fd5e4275bfe0b35147848700e53096e2a5f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,33 @@ +#librosa +numpy +opencv-contrib-python +opencv-python +torch +torchvision +tqdm +numba +basicsr +facexlib +lmdb +pyyaml +scipy +tb-nightly +yapf +imageio[ffmpeg] +batch-face +gdown +openai +diffusers +pathos +transformers +accelerate +mutagen +humanize +torchaudio +progressbar +p_tqdm +git+https://github.com/sberbank-ai/Real-ESRGAN.git +einops +rotary_embedding_torch +inflect +unidecode \ No newline at end of file diff --git a/results/1703435677_small.mp4 b/results/1703435677_small.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..08e4bc79776a4db3b9e748efa090f5a57c26a3d3 Binary files /dev/null and b/results/1703435677_small.mp4 differ diff --git a/results/1711455802_small.mp4 b/results/1711455802_small.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..94d795fbcc33c1264b747ebbd16c08cc0d9744d4 Binary files /dev/null and b/results/1711455802_small.mp4 differ diff --git a/results/1711514514_small.mp4 b/results/1711514514_small.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..052d479f5ea979c3d9121d088b4454e97bc48615 Binary files /dev/null and b/results/1711514514_small.mp4 differ diff --git a/results/1711515034_small.mp4 
b/results/1711515034_small.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..efd6c4422020926cfb6bfea57ea92bdd21d6d686 Binary files /dev/null and b/results/1711515034_small.mp4 differ diff --git a/results/1711522443_small.mp4 b/results/1711522443_small.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6d9dd3f0fa48b29591a9791b41bccb5adf66eda8 Binary files /dev/null and b/results/1711522443_small.mp4 differ diff --git a/results/1712691197_small.mp4 b/results/1712691197_small.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..415f4288d69d69970360e7c55e7b173f32a4b449 Binary files /dev/null and b/results/1712691197_small.mp4 differ diff --git a/results/1712731042_small.mp4 b/results/1712731042_small.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..b61e47028b3d501f482a45d9b6fe7e70b13929b7 Binary files /dev/null and b/results/1712731042_small.mp4 differ diff --git a/speech.py b/speech.py new file mode 100644 index 0000000000000000000000000000000000000000..8277b016a20e7cbef53f022d3772abaca5199285 --- /dev/null +++ b/speech.py @@ -0,0 +1,59 @@ +# from config import * +# from openai import OpenAI +# import os + +# def openai_generate_speech(audiofile, voice, text): +# client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) +# response = client.audio.speech.create( +# model="tts-1", +# voice=voice, +# input=text +# ) +# response.stream_to_file(audiofile) + +import os + +import torch +import torchaudio +import time +from tortoise.api import TextToSpeech +from tortoise.utils.audio import load_voices +import humanize +import datetime as dt + +def generate_speech(path_id, outfile, voice, text, speed="standard"): + tts = TextToSpeech(kv_cache=True, half=True) + selected_voices = voice.split(',') + for k, selected_voice in enumerate(selected_voices): + if '&' in selected_voice: + voice_sel = selected_voice.split('&') + else: + voice_sel = [selected_voice] + voice_samples, conditioning_latents = 
load_voices(voice_sel) + + gen, dbg_state = tts.tts_with_preset(text, k=1, voice_samples=voice_samples, + conditioning_latents=conditioning_latents, + return_deterministic_state=True, + preset=speed) + if isinstance(gen, list): + for j, g in enumerate(gen): + torchaudio.save(os.path.join("temp", path_id, outfile), g.squeeze(0).cpu(), 24000) + else: + torchaudio.save(os.path.join("temp", path_id, outfile), gen.squeeze(0).cpu(), 24000) + + + +if __name__ == '__main__': + path_id = os.path.join("temp", "audio", str(int(time.time()))) + os.makedirs(path_id, exist_ok=True) + tstart = time.time() + message = """Apple today confirmed that it will be permanently closing its Infinite Loop retail store in +Cupertino, California on January 20. Infinite Loop served as Apple's headquarters between the mid-1990s and +2017, when its current Apple Park headquarters opened a few miles away.""" + generate_speech(os.path.join("audio", str(int(time.time()))), "christmas.wav", "train_grace", + message, "ultra_fast") + + # openai_generate_speech("speech.mp3", "onyx", + # "Merry Christmas! 
May the holiday bring you endless joy, laughter, \ + # and quality time with friends and family!") + print("total time:", humanize.naturaldelta(dt.timedelta(seconds=int(time.time() - tstart)))) \ No newline at end of file diff --git a/temp/.DS_Store b/temp/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..04a72de05ee88842a0fe3f33a637bd882633915b Binary files /dev/null and b/temp/.DS_Store differ diff --git a/temp/1711455802/avatar.png b/temp/1711455802/avatar.png new file mode 100644 index 0000000000000000000000000000000000000000..6ca5ede9c64a79384d9a9a9d0c31c1fe7e334927 Binary files /dev/null and b/temp/1711455802/avatar.png differ diff --git a/temp/1711455802/temp.wav b/temp/1711455802/temp.wav new file mode 100644 index 0000000000000000000000000000000000000000..73c9908637ae249e456618023685b44c1bb9da8d Binary files /dev/null and b/temp/1711455802/temp.wav differ diff --git a/temp/1711455802/tmp.mp4 b/temp/1711455802/tmp.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..209216b7606b5060704d76be62fb2864d4d87a10 Binary files /dev/null and b/temp/1711455802/tmp.mp4 differ diff --git a/temp/1711514335/avatar.png b/temp/1711514335/avatar.png new file mode 100644 index 0000000000000000000000000000000000000000..5f6691fecc8170f75190422f6c17350ba157b7f5 Binary files /dev/null and b/temp/1711514335/avatar.png differ diff --git a/temp/1711514514/temp.avi b/temp/1711514514/temp.avi new file mode 100644 index 0000000000000000000000000000000000000000..4855d8400ba8f47815cfba3bac6813556899ffd3 Binary files /dev/null and b/temp/1711514514/temp.avi differ diff --git a/temp/1711514514/temp.wav b/temp/1711514514/temp.wav new file mode 100644 index 0000000000000000000000000000000000000000..019dd3b1988c9dd9955a18527e5dd74b2b780ff7 Binary files /dev/null and b/temp/1711514514/temp.wav differ diff --git a/temp/1711515034/avatar.png b/temp/1711515034/avatar.png new file mode 100644 index 
0000000000000000000000000000000000000000..71cb5d6936bca1bdd9a6df543b125e444f06cb1b Binary files /dev/null and b/temp/1711515034/avatar.png differ diff --git a/temp/1711515034/temp.avi b/temp/1711515034/temp.avi new file mode 100644 index 0000000000000000000000000000000000000000..3dec53e23bb3012f8e050df1a83a7923e4d58556 Binary files /dev/null and b/temp/1711515034/temp.avi differ diff --git a/temp/1711515034/temp.wav b/temp/1711515034/temp.wav new file mode 100644 index 0000000000000000000000000000000000000000..add140fdfe9af10474d96fd55c7ce8e1c6fe2a5f Binary files /dev/null and b/temp/1711515034/temp.wav differ diff --git a/temp/1711522443/temp.avi b/temp/1711522443/temp.avi new file mode 100644 index 0000000000000000000000000000000000000000..2ea8d3672acb97eba643d334cca46141dc41fb8a Binary files /dev/null and b/temp/1711522443/temp.avi differ diff --git a/temp/1711522443/temp.wav b/temp/1711522443/temp.wav new file mode 100644 index 0000000000000000000000000000000000000000..5bb28957b16d6cca9d618e4ea1f127d537170987 Binary files /dev/null and b/temp/1711522443/temp.wav differ diff --git a/temp/1712691197/temp.avi b/temp/1712691197/temp.avi new file mode 100644 index 0000000000000000000000000000000000000000..a9576607e4910e8802c21f7e95dad43384f00b10 Binary files /dev/null and b/temp/1712691197/temp.avi differ diff --git a/temp/1712691197/temp.wav b/temp/1712691197/temp.wav new file mode 100644 index 0000000000000000000000000000000000000000..0205b2b207a05546e38e2377055645c969743f5a Binary files /dev/null and b/temp/1712691197/temp.wav differ diff --git a/temp/1712691197/tmp.wav b/temp/1712691197/tmp.wav new file mode 100644 index 0000000000000000000000000000000000000000..0205b2b207a05546e38e2377055645c969743f5a Binary files /dev/null and b/temp/1712691197/tmp.wav differ diff --git a/temp/1712731042/temp.avi b/temp/1712731042/temp.avi new file mode 100644 index 0000000000000000000000000000000000000000..8a4ab66cc84ff20eb3ece83ea0f550bd7f1fdf69 Binary files /dev/null and 
b/temp/1712731042/temp.avi differ diff --git a/temp/1712731042/temp.wav b/temp/1712731042/temp.wav new file mode 100644 index 0000000000000000000000000000000000000000..5ce2afcd0f014cb27d78f5ccc5e3a1eb09734d3b Binary files /dev/null and b/temp/1712731042/temp.wav differ diff --git a/temp/1712731042/tmp.wav b/temp/1712731042/tmp.wav new file mode 100644 index 0000000000000000000000000000000000000000..5ce2afcd0f014cb27d78f5ccc5e3a1eb09734d3b Binary files /dev/null and b/temp/1712731042/tmp.wav differ diff --git a/temp/faulty_frame.jpg b/temp/faulty_frame.jpg new file mode 100644 index 0000000000000000000000000000000000000000..425f2155eae0e80614bd705af207b772d27689d5 Binary files /dev/null and b/temp/faulty_frame.jpg differ diff --git a/tortoise/__init__.py b/tortoise/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tortoise/__pycache__/__init__.cpython-310.pyc b/tortoise/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2721ff5df7d50deb17956ab364d2b8d336d9736f Binary files /dev/null and b/tortoise/__pycache__/__init__.cpython-310.pyc differ diff --git a/tortoise/__pycache__/api.cpython-310.pyc b/tortoise/__pycache__/api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f51d4a75acd178240dd1f7dc8889d14feeea184 Binary files /dev/null and b/tortoise/__pycache__/api.cpython-310.pyc differ diff --git a/tortoise/api.py b/tortoise/api.py new file mode 100644 index 0000000000000000000000000000000000000000..69807b18aced06dd7d29f169e18bef2b4ecda8b9 --- /dev/null +++ b/tortoise/api.py @@ -0,0 +1,598 @@ +import os +import random +import uuid +from time import time +from urllib import request + +import torch +import torch.nn.functional as F +import progressbar +import torchaudio + +from tortoise.models.classifier import AudioMiniEncoderWithClassifierHead +from tortoise.models.diffusion_decoder import 
DiffusionTts +from tortoise.models.autoregressive import UnifiedVoice +from tqdm import tqdm +from tortoise.models.arch_util import TorchMelSpectrogram +from tortoise.models.clvp import CLVP +from tortoise.models.cvvp import CVVP +from tortoise.models.random_latent_generator import RandomLatentConverter +from tortoise.models.vocoder import UnivNetGenerator +from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel +from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule +from tortoise.utils.tokenizer import VoiceBpeTokenizer +from tortoise.utils.wav2vec_alignment import Wav2VecAlignment +from contextlib import contextmanager +from huggingface_hub import hf_hub_download +pbar = None + +DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'tortoise', 'models') +MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', DEFAULT_MODELS_DIR) +MODELS = { + 'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth', + 'classifier.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth', + 'clvp2.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth', + 'cvvp.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/cvvp.pth', + 'diffusion_decoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth', + 'vocoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth', + 'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth', + 'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth', +} + +def get_model_path(model_name, models_dir=MODELS_DIR): + """ + Get path to given model, download it if it doesn't exist. 
+ """ + if model_name not in MODELS: + raise ValueError(f'Model {model_name} not found in available models.') + model_path = hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=models_dir) + return model_path + + +def pad_or_truncate(t, length): + """ + Utility function for forcing to have the specified sequence length, whether by clipping it or padding it with 0s. + """ + if t.shape[-1] == length: + return t + elif t.shape[-1] < length: + return F.pad(t, (0, length-t.shape[-1])) + else: + return t[..., :length] + + +def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1): + """ + Helper function to load a GaussianDiffusion instance configured for use as a vocoder. + """ + return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon', + model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps), + conditioning_free=cond_free, conditioning_free_k=cond_free_k) + + +def format_conditioning(clip, cond_length=132300, device="cuda" if not torch.backends.mps.is_available() else 'mps'): + """ + Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models. + """ + gap = clip.shape[-1] - cond_length + if gap < 0: + clip = F.pad(clip, pad=(0, abs(gap))) + elif gap > 0: + rand_start = random.randint(0, gap) + clip = clip[:, rand_start:rand_start + cond_length] + mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0) + return mel_clip.unsqueeze(0).to(device) + + +def fix_autoregressive_output(codes, stop_token, complain=True): + """ + This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was + trained on and what the autoregressive code generator creates (which has no padding or end). 
+ This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with + a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE + and copying out the last few codes. + + Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar. + """ + # Strip off the autoregressive stop token and add padding. + stop_token_indices = (codes == stop_token).nonzero() + if len(stop_token_indices) == 0: + if complain: + print("No stop tokens found in one of the generated voice clips. This typically means the spoken audio is " + "too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, " + "try breaking up your input text.") + return codes + else: + codes[stop_token_indices] = 83 + stm = stop_token_indices.min().item() + codes[stm:] = 83 + if stm - 3 < codes.shape[0]: + codes[-3] = 45 + codes[-2] = 45 + codes[-1] = 248 + + return codes + + +def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_latents, temperature=1, verbose=True): + """ + Uses the specified diffusion model to convert discrete codes into a spectrogram. + """ + with torch.no_grad(): + output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. 
+ output_shape = (latents.shape[0], 100, output_seq_len) + precomputed_embeddings = diffusion_model.timestep_independent(latents, conditioning_latents, output_seq_len, False) + + noise = torch.randn(output_shape, device=latents.device) * temperature + mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise, + model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings}, + progress=verbose) + return denormalize_tacotron_mel(mel)[:,:,:output_seq_len] + + +def classify_audio_clip(clip): + """ + Returns whether or not Tortoises' classifier thinks the given clip came from Tortoise. + :param clip: torch tensor containing audio waveform data (get it from load_audio) + :return: True if the clip was classified as coming from Tortoise and false if it was classified as real. + """ + classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4, + resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32, + dropout=0, kernel_size=5, distribute_zero_label=False) + classifier.load_state_dict(torch.load(get_model_path('classifier.pth'), map_location=torch.device('cpu'))) + clip = clip.cpu().unsqueeze(0) + results = F.softmax(classifier(clip), dim=-1) + return results[0][0] + + +def pick_best_batch_size_for_gpu(): + """ + Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give + you a good shot. + """ + if torch.cuda.is_available(): + _, available = torch.cuda.mem_get_info() + availableGb = available / (1024 ** 3) + if availableGb > 14: + return 16 + elif availableGb > 10: + return 8 + elif availableGb > 7: + return 4 + if torch.backends.mps.is_available(): + import psutil + available = psutil.virtual_memory().total + availableGb = available / (1024 ** 3) + if availableGb > 14: + return 16 + elif availableGb > 10: + return 8 + elif availableGb > 7: + return 4 + return 1 + +class TextToSpeech: + """ + Main entry point into Tortoise. 
+ """ + + def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, + enable_redaction=True, kv_cache=False, use_deepspeed=False, half=False, device=None, + tokenizer_vocab_file=None, tokenizer_basic=False): + + """ + Constructor + :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing + GPU OOM errors. Larger numbers generates slightly faster. + :param models_dir: Where model weights are stored. This should only be specified if you are providing your own + models, otherwise use the defaults. + :param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output + (but are still rendered by the model). This can be used for prompt engineering. + Default is true. + :param device: Device to use when running the model. If omitted, the device will be automatically chosen. + """ + self.models_dir = models_dir + self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size + self.enable_redaction = enable_redaction + self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu') + if torch.backends.mps.is_available(): + self.device = torch.device('mps') + if self.enable_redaction: + self.aligner = Wav2VecAlignment() + + self.tokenizer = VoiceBpeTokenizer( + vocab_file=tokenizer_vocab_file, + use_basic_cleaners=tokenizer_basic, + ) + self.half = half + if os.path.exists(f'{models_dir}/autoregressive.ptt'): + # Assume this is a traced directory. 
+ self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt') + self.diffusion = torch.jit.load(f'{models_dir}/diffusion_decoder.ptt') + else: + self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30, + model_dim=1024, + heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, + train_solo_embeddings=False).cpu().eval() + self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)), strict=False) + self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache, half=self.half) + + self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, + in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, + layer_drop=0, unconditioned_percentage=0).cpu().eval() + self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', models_dir))) + + self.clvp = CLVP(dim_text=768, dim_speech=768, dim_latent=768, num_text_tokens=256, text_enc_depth=20, + text_seq_len=350, text_heads=12, + num_speech_tokens=8192, speech_enc_depth=20, speech_heads=12, speech_seq_len=430, + use_xformers=True).cpu().eval() + self.clvp.load_state_dict(torch.load(get_model_path('clvp2.pth', models_dir))) + self.cvvp = None # CVVP model is only loaded if used. + + self.vocoder = UnivNetGenerator().cpu() + self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g']) + self.vocoder.eval(inference=True) + + # Random latent generators (RLGs) are loaded lazily. 
+ self.rlg_auto = None + self.rlg_diffusion = None + @contextmanager + def temporary_cuda(self, model): + m = model.to(self.device) + yield m + m = model.cpu() + + + def load_cvvp(self): + """Load CVVP model.""" + self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0, + speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval() + self.cvvp.load_state_dict(torch.load(get_model_path('cvvp.pth', self.models_dir))) + + def get_conditioning_latents(self, voice_samples, return_mels=False): + """ + Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent). + These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic + properties. + :param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data. + """ + with torch.no_grad(): + voice_samples = [v.to(self.device) for v in voice_samples] + + auto_conds = [] + if not isinstance(voice_samples, list): + voice_samples = [voice_samples] + for vs in voice_samples: + auto_conds.append(format_conditioning(vs, device=self.device)) + auto_conds = torch.stack(auto_conds, dim=1) + self.autoregressive = self.autoregressive.to(self.device) + auto_latent = self.autoregressive.get_conditioning(auto_conds) + self.autoregressive = self.autoregressive.cpu() + + diffusion_conds = [] + for sample in voice_samples: + # The diffuser operates at a sample rate of 24000 (except for the latent inputs) + sample = torchaudio.functional.resample(sample, 22050, 24000) + sample = pad_or_truncate(sample, 102400) + cond_mel = wav_to_univnet_mel(sample.to(self.device), do_normalization=False, device=self.device) + diffusion_conds.append(cond_mel) + diffusion_conds = torch.stack(diffusion_conds, dim=1) + + self.diffusion = self.diffusion.to(self.device) + diffusion_latent = 
self.diffusion.get_conditioning(diffusion_conds) + self.diffusion = self.diffusion.cpu() + + if return_mels: + return auto_latent, diffusion_latent, auto_conds, diffusion_conds + else: + return auto_latent, diffusion_latent + + def get_random_conditioning_latents(self): + # Lazy-load the RLG models. + if self.rlg_auto is None: + self.rlg_auto = RandomLatentConverter(1024).eval() + self.rlg_auto.load_state_dict(torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=torch.device('cpu'))) + self.rlg_diffusion = RandomLatentConverter(2048).eval() + self.rlg_diffusion.load_state_dict(torch.load(get_model_path('rlg_diffuser.pth', self.models_dir), map_location=torch.device('cpu'))) + with torch.no_grad(): + return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0])) + + def tts_with_preset(self, text, preset='fast', **kwargs): + """ + Calls TTS with one of a set of preset generation parameters. Options: + 'ultra_fast': Produces speech at a speed which belies the name of this repo. (Not really, but it's definitely fastest). + 'fast': Decent quality speech at a decent inference rate. A good choice for mass inference. + 'standard': Very good quality. This is generally about as good as you are going to get. + 'high_quality': Use if you want the absolute best. This is not really worth the compute, though. + """ + # Use generally found best tuning knobs for generation. + settings = {'temperature': .8, 'length_penalty': 1.0, 'repetition_penalty': 2.0, + 'top_p': .8, + 'cond_free_k': 2.0, 'diffusion_temperature': 1.0} + # Presets are defined here. 
+ presets = { + 'ultra_fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 30, 'cond_free': False}, + 'fast': {'num_autoregressive_samples': 96, 'diffusion_iterations': 80}, + 'standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 200}, + 'high_quality': {'num_autoregressive_samples': 256, 'diffusion_iterations': 400}, + } + settings.update(presets[preset]) + settings.update(kwargs) # allow overriding of preset settings with kwargs + return self.tts(text, **settings) + + def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None, + return_deterministic_state=False, + # autoregressive generation parameters follow + num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500, + # CVVP parameters follow + cvvp_amount=.0, + # diffusion generation parameters follow + diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0, + **hf_generate_kwargs): + """ + Produces an audio clip of the given text being spoken with the given reference voice. + :param text: Text to be spoken. + :param voice_samples: List of 2 or more ~10 second reference clips which should be torch tensors containing 22.05kHz waveform data. + :param conditioning_latents: A tuple of (autoregressive_conditioning_latent, diffusion_conditioning_latent), which + can be provided in lieu of voice_samples. This is ignored unless voice_samples=None. + Conditioning latents can be retrieved via get_conditioning_latents(). + :param k: The number of returned clips. The most likely (as determined by Tortoises' CLVP model) clips are returned. + :param verbose: Whether or not to print log messages indicating the progress of creating a clip. Default=true. + ~~AUTOREGRESSIVE KNOBS~~ + :param num_autoregressive_samples: Number of samples taken from the autoregressive model, all of which are filtered using CLVP. 
+ As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great". + :param temperature: The softmax temperature of the autoregressive model. + :param length_penalty: A length penalty applied to the autoregressive decoder. Higher settings causes the model to produce more terse outputs. + :param repetition_penalty: A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce the incidence + of long silences or "uhhhhhhs", etc. + :param top_p: P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs. + :param max_mel_tokens: Restricts the output length. (0,600] integer. Each unit is 1/20 of a second. + :param typical_sampling: Turns typical sampling on or off. This sampling mode is discussed in this paper: https://arxiv.org/abs/2202.00666 + I was interested in the premise, but the results were not as good as I was hoping. This is off by default, but + could use some tuning. + :param typical_mass: The typical_mass parameter from the typical_sampling algorithm. + ~~CLVP-CVVP KNOBS~~ + :param cvvp_amount: Controls the influence of the CVVP model in selecting the best output from the autoregressive model. + [0,1]. Values closer to 1 mean the CVVP model is more important, 0 disables the CVVP model. + ~~DIFFUSION KNOBS~~ + :param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine + the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, + however. + :param cond_free: Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for + each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output + of the two is blended according to the cond_free_k value below. 
Conditioning-free diffusion is the real deal, and + dramatically improves realism. + :param cond_free_k: Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf]. + As cond_free_k increases, the output becomes dominated by the conditioning-free signal. + Formula is: output=cond_present_output*(cond_free_k+1)-cond_absenct_output*cond_free_k + :param diffusion_temperature: Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0 + are the "mean" prediction of the diffusion network and will sound bland and smeared. + ~~OTHER STUFF~~ + :param hf_generate_kwargs: The huggingface Transformers generate API is used for the autoregressive transformer. + Extra keyword args fed to this function get forwarded directly to that API. Documentation + here: https://huggingface.co/docs/transformers/internal/generation_utils + :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. + Sample rate is 24kHz. + """ + deterministic_seed = self.deterministic_state(seed=use_deterministic_seed) + + text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device) + text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. + assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.' 
+ auto_conds = None + if voice_samples is not None: + auto_conditioning, diffusion_conditioning, auto_conds, _ = self.get_conditioning_latents(voice_samples, return_mels=True) + elif conditioning_latents is not None: + auto_conditioning, diffusion_conditioning = conditioning_latents + else: + auto_conditioning, diffusion_conditioning = self.get_random_conditioning_latents() + auto_conditioning = auto_conditioning.to(self.device) + diffusion_conditioning = diffusion_conditioning.to(self.device) + + diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k) + + with torch.no_grad(): + samples = [] + num_batches = num_autoregressive_samples // self.autoregressive_batch_size + stop_mel_token = self.autoregressive.stop_mel_token + calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" + if verbose: + print("Generating autoregressive samples..") + if not torch.backends.mps.is_available(): + with self.temporary_cuda(self.autoregressive + ) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=self.half): + for b in tqdm(range(num_batches), disable=not verbose): + codes = autoregressive.inference_speech(auto_conditioning, text_tokens, + do_sample=True, + top_p=top_p, + temperature=temperature, + num_return_sequences=self.autoregressive_batch_size, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + max_generate_length=max_mel_tokens, + **hf_generate_kwargs) + padding_needed = max_mel_tokens - codes.shape[1] + codes = F.pad(codes, (0, padding_needed), value=stop_mel_token) + samples.append(codes) + else: + with self.temporary_cuda(self.autoregressive) as autoregressive: + for b in tqdm(range(num_batches), disable=not verbose): + codes = autoregressive.inference_speech(auto_conditioning, text_tokens, + do_sample=True, + top_p=top_p, + temperature=temperature, + 
num_return_sequences=self.autoregressive_batch_size, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + max_generate_length=max_mel_tokens, + **hf_generate_kwargs) + padding_needed = max_mel_tokens - codes.shape[1] + codes = F.pad(codes, (0, padding_needed), value=stop_mel_token) + samples.append(codes) + + clip_results = [] + + if not torch.backends.mps.is_available(): + with self.temporary_cuda(self.clvp) as clvp, torch.autocast( + device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half + ): + if cvvp_amount > 0: + if self.cvvp is None: + self.load_cvvp() + self.cvvp = self.cvvp.to(self.device) + if verbose: + if self.cvvp is None: + print("Computing best candidates using CLVP") + else: + print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%") + for batch in tqdm(samples, disable=not verbose): + for i in range(batch.shape[0]): + batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) + if cvvp_amount != 1: + clvp_out = clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False) + if auto_conds is not None and cvvp_amount > 0: + cvvp_accumulator = 0 + for cl in range(auto_conds.shape[1]): + cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False) + cvvp = cvvp_accumulator / auto_conds.shape[1] + if cvvp_amount == 1: + clip_results.append(cvvp) + else: + clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount)) + else: + clip_results.append(clvp_out) + clip_results = torch.cat(clip_results, dim=0) + samples = torch.cat(samples, dim=0) + best_results = samples[torch.topk(clip_results, k=k).indices] + else: + with self.temporary_cuda(self.clvp) as clvp: + if cvvp_amount > 0: + if self.cvvp is None: + self.load_cvvp() + self.cvvp = self.cvvp.to(self.device) + if verbose: + if self.cvvp is None: + print("Computing best candidates using 
CLVP") + else: + print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%") + for batch in tqdm(samples, disable=not verbose): + for i in range(batch.shape[0]): + batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) + if cvvp_amount != 1: + clvp_out = clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False) + if auto_conds is not None and cvvp_amount > 0: + cvvp_accumulator = 0 + for cl in range(auto_conds.shape[1]): + cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False) + cvvp = cvvp_accumulator / auto_conds.shape[1] + if cvvp_amount == 1: + clip_results.append(cvvp) + else: + clip_results.append(cvvp * cvvp_amount + clvp_out * (1-cvvp_amount)) + else: + clip_results.append(clvp_out) + clip_results = torch.cat(clip_results, dim=0) + samples = torch.cat(samples, dim=0) + best_results = samples[torch.topk(clip_results, k=k).indices] + if self.cvvp is not None: + self.cvvp = self.cvvp.cpu() + del samples + + # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning + # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these + # results, but will increase memory usage. 
+ if not torch.backends.mps.is_available(): + with self.temporary_cuda( + self.autoregressive + ) as autoregressive, torch.autocast( + device_type="cuda" if not torch.backends.mps.is_available() else 'mps', dtype=torch.float16, enabled=self.half + ): + best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1), + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results, + torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device), + return_latent=True, clip_inputs=False) + del auto_conditioning + else: + with self.temporary_cuda( + self.autoregressive + ) as autoregressive: + best_latents = autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1), + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results, + torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device), + return_latent=True, clip_inputs=False) + del auto_conditioning + + if verbose: + print("Transforming autoregressive outputs into audio..") + wav_candidates = [] + if not torch.backends.mps.is_available(): + with self.temporary_cuda(self.diffusion) as diffusion, self.temporary_cuda( + self.vocoder + ) as vocoder: + for b in range(best_results.shape[0]): + codes = best_results[b].unsqueeze(0) + latents = best_latents[b].unsqueeze(0) + + # Find the first occurrence of the "calm" token and trim the codes to that. + ctokens = 0 + for k in range(codes.shape[-1]): + if codes[0, k] == calm_token: + ctokens += 1 + else: + ctokens = 0 + if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech. 
+ latents = latents[:, :k] + break + mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning, temperature=diffusion_temperature, + verbose=verbose) + wav = vocoder.inference(mel) + wav_candidates.append(wav.cpu()) + else: + diffusion, vocoder = self.diffusion, self.vocoder + diffusion_conditioning = diffusion_conditioning.cpu() + for b in range(best_results.shape[0]): + codes = best_results[b].unsqueeze(0).cpu() + latents = best_latents[b].unsqueeze(0).cpu() + + # Find the first occurrence of the "calm" token and trim the codes to that. + ctokens = 0 + for k in range(codes.shape[-1]): + if codes[0, k] == calm_token: + ctokens += 1 + else: + ctokens = 0 + if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech. + latents = latents[:, :k] + break + mel = do_spectrogram_diffusion(diffusion, diffuser, latents, diffusion_conditioning, temperature=diffusion_temperature, + verbose=verbose) + wav = vocoder.inference(mel) + wav_candidates.append(wav.cpu()) + + def potentially_redact(clip, text): + if self.enable_redaction: + return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1) + return clip + wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates] + + if len(wav_candidates) > 1: + res = wav_candidates + else: + res = wav_candidates[0] + + if return_deterministic_state: + return res, (deterministic_seed, text, voice_samples, conditioning_latents) + else: + return res + def deterministic_state(self, seed=None): + """ + Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be + reproduced. + """ + seed = int(time()) if seed is None else seed + torch.manual_seed(seed) + random.seed(seed) + # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary. 
+ # torch.use_deterministic_algorithms(True) + + return seed diff --git a/tortoise/api_fast.py b/tortoise/api_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..216ecf2df06135b3d53e012f4b7089216a5e9523 --- /dev/null +++ b/tortoise/api_fast.py @@ -0,0 +1,511 @@ +import os +import random +import uuid +from time import time +from urllib import request + +import torch +import torch.nn.functional as F +import progressbar +import torchaudio +import numpy as np +from tortoise.models.classifier import AudioMiniEncoderWithClassifierHead +from tortoise.models.diffusion_decoder import DiffusionTts +from tortoise.models.autoregressive import UnifiedVoice +from tqdm import tqdm +from tortoise.models.arch_util import TorchMelSpectrogram +from tortoise.models.clvp import CLVP +from tortoise.models.cvvp import CVVP +from tortoise.models.hifigan_decoder import HifiganGenerator +from tortoise.models.random_latent_generator import RandomLatentConverter +from tortoise.models.vocoder import UnivNetGenerator +from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel +from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule +from tortoise.utils.tokenizer import VoiceBpeTokenizer +from tortoise.utils.wav2vec_alignment import Wav2VecAlignment +from contextlib import contextmanager +from tortoise.models.stream_generator import init_stream_support +from huggingface_hub import hf_hub_download +pbar = None +init_stream_support() +DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'tortoise', 'models') +MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', DEFAULT_MODELS_DIR) + +MODELS = { + 'autoregressive.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/autoregressive.pth', + 'classifier.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/classifier.pth', + 'rlg_auto.pth': 'https://huggingface.co/Manmay/tortoise-tts/resolve/main/rlg_auto.pth', + 'hifidecoder.pth': 
'https://huggingface.co/Manmay/tortoise-tts/resolve/main/hifidecoder.pth', +} + +def get_model_path(model_name, models_dir=MODELS_DIR): + """ + Get path to given model, download it if it doesn't exist. + """ + if model_name not in MODELS: + raise ValueError(f'Model {model_name} not found in available models.') + model_path = hf_hub_download(repo_id="Manmay/tortoise-tts", filename=model_name, cache_dir=models_dir) + return model_path + + +def pad_or_truncate(t, length): + """ + Utility function for forcing to have the specified sequence length, whether by clipping it or padding it with 0s. + """ + if t.shape[-1] == length: + return t + elif t.shape[-1] < length: + return F.pad(t, (0, length-t.shape[-1])) + else: + return t[..., :length] + + +def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1): + """ + Helper function to load a GaussianDiffusion instance configured for use as a vocoder. + """ + return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon', + model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps), + conditioning_free=cond_free, conditioning_free_k=cond_free_k) + + +def format_conditioning(clip, cond_length=132300, device="cuda" if not torch.backends.mps.is_available() else 'mps'): + """ + Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models. 
+ """ + gap = clip.shape[-1] - cond_length + if gap < 0: + clip = F.pad(clip, pad=(0, abs(gap))) + elif gap > 0: + rand_start = random.randint(0, gap) + clip = clip[:, rand_start:rand_start + cond_length] + mel_clip = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0) + return mel_clip.unsqueeze(0).to(device) + + +def fix_autoregressive_output(codes, stop_token, complain=True): + """ + This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was + trained on and what the autoregressive code generator creates (which has no padding or end). + This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with + a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE + and copying out the last few codes. + + Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar. + """ + # Strip off the autoregressive stop token and add padding. + stop_token_indices = (codes == stop_token).nonzero() + if len(stop_token_indices) == 0: + if complain: + print("No stop tokens found in one of the generated voice clips. This typically means the spoken audio is " + "too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, " + "try breaking up your input text.") + return codes + else: + codes[stop_token_indices] = 83 + stm = stop_token_indices.min().item() + codes[stm:] = 83 + if stm - 3 < codes.shape[0]: + codes[-3] = 45 + codes[-2] = 45 + codes[-1] = 248 + + return codes + + +def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_latents, temperature=1, verbose=True): + """ + Uses the specified diffusion model to convert discrete codes into a spectrogram. 
+ """ + with torch.no_grad(): + output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. + output_shape = (latents.shape[0], 100, output_seq_len) + precomputed_embeddings = diffusion_model.timestep_independent(latents, conditioning_latents, output_seq_len, False) + + noise = torch.randn(output_shape, device=latents.device) * temperature + mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise, + model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings}, + progress=verbose) + return denormalize_tacotron_mel(mel)[:,:,:output_seq_len] + + +def classify_audio_clip(clip): + """ + Returns whether or not Tortoises' classifier thinks the given clip came from Tortoise. + :param clip: torch tensor containing audio waveform data (get it from load_audio) + :return: True if the clip was classified as coming from Tortoise and false if it was classified as real. + """ + classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4, + resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32, + dropout=0, kernel_size=5, distribute_zero_label=False) + classifier.load_state_dict(torch.load(get_model_path('classifier.pth'), map_location=torch.device('cpu'))) + clip = clip.cpu().unsqueeze(0) + results = F.softmax(classifier(clip), dim=-1) + return results[0][0] + + +def pick_best_batch_size_for_gpu(): + """ + Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give + you a good shot. 
+ """ + if torch.cuda.is_available(): + _, available = torch.cuda.mem_get_info() + availableGb = available / (1024 ** 3) + if availableGb > 14: + return 16 + elif availableGb > 10: + return 8 + elif availableGb > 7: + return 4 + if torch.backends.mps.is_available(): + import psutil + available = psutil.virtual_memory().total + availableGb = available / (1024 ** 3) + if availableGb > 14: + return 16 + elif availableGb > 10: + return 8 + elif availableGb > 7: + return 4 + return 1 + +class TextToSpeech: + """ + Main entry point into Tortoise. + """ + + def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, + enable_redaction=True, kv_cache=False, use_deepspeed=False, half=False, device=None, + tokenizer_vocab_file=None, tokenizer_basic=False): + + """ + Constructor + :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing + GPU OOM errors. Larger numbers generates slightly faster. + :param models_dir: Where model weights are stored. This should only be specified if you are providing your own + models, otherwise use the defaults. + :param enable_redaction: When true, text enclosed in brackets are automatically redacted from the spoken output + (but are still rendered by the model). This can be used for prompt engineering. + Default is true. + :param device: Device to use when running the model. If omitted, the device will be automatically chosen. 
+ """ + self.models_dir = models_dir + self.autoregressive_batch_size = pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size + self.enable_redaction = enable_redaction + self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu') + if torch.backends.mps.is_available(): + self.device = torch.device('mps') + if self.enable_redaction: + self.aligner = Wav2VecAlignment() + + self.tokenizer = VoiceBpeTokenizer( + vocab_file=tokenizer_vocab_file, + use_basic_cleaners=tokenizer_basic, + ) + self.half = half + if os.path.exists(f'{models_dir}/autoregressive.ptt'): + # Assume this is a traced directory. + self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt') + else: + self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30, + model_dim=1024, + heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, + train_solo_embeddings=False).to(self.device).eval() + self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)), strict=False) + self.autoregressive.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=kv_cache, half=self.half) + + self.hifi_decoder = HifiganGenerator(in_channels=1024, out_channels = 1, resblock_type = "1", + resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], resblock_kernel_sizes = [3, 7, 11], + upsample_kernel_sizes = [16, 16, 4, 4], upsample_initial_channel = 512, upsample_factors = [8, 8, 2, 2], + cond_channels=1024).to(self.device).eval() + hifi_model = torch.load(get_model_path('hifidecoder.pth')) + self.hifi_decoder.load_state_dict(hifi_model, strict=False) + # Random latent generators (RLGs) are loaded lazily. + self.rlg_auto = None + def get_conditioning_latents(self, voice_samples, return_mels=False): + """ + Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent). 
+ These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic + properties. + :param voice_samples: List of 2 or more ~10 second reference clips, which should be torch tensors containing 22.05kHz waveform data. + """ + with torch.no_grad(): + voice_samples = [v.to(self.device) for v in voice_samples] + + auto_conds = [] + if not isinstance(voice_samples, list): + voice_samples = [voice_samples] + for vs in voice_samples: + auto_conds.append(format_conditioning(vs, device=self.device)) + auto_conds = torch.stack(auto_conds, dim=1) + auto_latent = self.autoregressive.get_conditioning(auto_conds) + + if return_mels: + return auto_latent + else: + return auto_latent + + def get_random_conditioning_latents(self): + # Lazy-load the RLG models. + if self.rlg_auto is None: + self.rlg_auto = RandomLatentConverter(1024).eval() + self.rlg_auto.load_state_dict(torch.load(get_model_path('rlg_auto.pth', self.models_dir), map_location=torch.device('cpu'))) + with torch.no_grad(): + return self.rlg_auto(torch.tensor([0.0])) + + def tts_with_preset(self, text, preset='fast', **kwargs): + """ + Calls TTS with one of a set of preset generation parameters. Options: + 'ultra_fast': Produces speech at a speed which belies the name of this repo. (Not really, but it's definitely fastest). + 'fast': Decent quality speech at a decent inference rate. A good choice for mass inference. + 'standard': Very good quality. This is generally about as good as you are going to get. + 'high_quality': Use if you want the absolute best. This is not really worth the compute, though. + """ + # Use generally found best tuning knobs for generation. + settings = {'temperature': .8, 'length_penalty': 1.0, 'repetition_penalty': 2.0, + 'top_p': .8, + 'cond_free_k': 2.0, 'diffusion_temperature': 1.0} + # Presets are defined here. 
+ presets = { + 'ultra_fast': {'num_autoregressive_samples': 1, 'diffusion_iterations': 10}, + 'fast': {'num_autoregressive_samples': 32, 'diffusion_iterations': 50}, + 'standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 200}, + 'high_quality': {'num_autoregressive_samples': 256, 'diffusion_iterations': 400}, + } + settings.update(presets[preset]) + settings.update(kwargs) # allow overriding of preset settings with kwargs + for audio_frame in self.tts(text, **settings): + yield audio_frame + # taken from here https://github.com/coqui-ai/TTS/blob/b4c552a112fd4c5f3477f439882eb43c2e2ce85f/TTS/tts/models/xtts.py#L600 + def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len): + """Handle chunk formatting in streaming mode""" + wav_chunk = wav_gen[:-overlap_len] + if wav_gen_prev is not None: + wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len] + if wav_overlap is not None: + # cross fade the overlap section + if overlap_len > len(wav_chunk): + # wav_chunk is smaller than overlap_len, pass on last wav_gen + if wav_gen_prev is not None: + wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len):] + else: + # not expecting will hit here as problem happens on last chunk + wav_chunk = wav_gen[-overlap_len:] + return wav_chunk, wav_gen, None + else: + crossfade_wav = wav_chunk[:overlap_len] + crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device) + wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device) + wav_chunk[:overlap_len] += crossfade_wav + + wav_overlap = wav_gen[-overlap_len:] + wav_gen_prev = wav_gen + return wav_chunk, wav_gen_prev, wav_overlap + + + def tts_stream(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None, + return_deterministic_state=False, overlap_wav_len=1024, stream_chunk_size=40, + # autoregressive generation parameters follow + 
num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500, + # CVVP parameters follow + cvvp_amount=.0, + # diffusion generation parameters follow + diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0, + **hf_generate_kwargs): + """ + Produces an audio clip of the given text being spoken with the given reference voice. + :param text: Text to be spoken. + :param voice_samples: List of 2 or more ~10 second reference clips which should be torch tensors containing 22.05kHz waveform data. + :param conditioning_latents: A tuple of (autoregressive_conditioning_latent, diffusion_conditioning_latent), which + can be provided in lieu of voice_samples. This is ignored unless voice_samples=None. + Conditioning latents can be retrieved via get_conditioning_latents(). + :param k: The number of returned clips. The most likely (as determined by Tortoises' CLVP model) clips are returned. + :param verbose: Whether or not to print log messages indicating the progress of creating a clip. Default=true. + ~~AUTOREGRESSIVE KNOBS~~ + :param num_autoregressive_samples: Number of samples taken from the autoregressive model, all of which are filtered using CLVP. + As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great". + :param temperature: The softmax temperature of the autoregressive model. + :param length_penalty: A length penalty applied to the autoregressive decoder. Higher settings causes the model to produce more terse outputs. + :param repetition_penalty: A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce the incidence + of long silences or "uhhhhhhs", etc. + :param top_p: P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs. + :param max_mel_tokens: Restricts the output length. (0,600] integer. 
Each unit is 1/20 of a second. + ~~DIFFUSION KNOBS~~ + :param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine + the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, + however. + :param cond_free: Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for + each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output + of the two is blended according to the cond_free_k value below. Conditioning-free diffusion is the real deal, and + dramatically improves realism. + :param cond_free_k: Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf]. + As cond_free_k increases, the output becomes dominated by the conditioning-free signal. + Formula is: output=cond_present_output*(cond_free_k+1)-cond_absenct_output*cond_free_k + :param diffusion_temperature: Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0 + are the "mean" prediction of the diffusion network and will sound bland and smeared. + ~~OTHER STUFF~~ + :param hf_generate_kwargs: The huggingface Transformers generate API is used for the autoregressive transformer. + Extra keyword args fed to this function get forwarded directly to that API. Documentation + here: https://huggingface.co/docs/transformers/internal/generation_utils + :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. + Sample rate is 24kHz. + """ + deterministic_seed = self.deterministic_state(seed=use_deterministic_seed) + + text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device) + text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. 
+ assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.' + if voice_samples is not None: + auto_conditioning = self.get_conditioning_latents(voice_samples, return_mels=False) + else: + auto_conditioning = self.get_random_conditioning_latents() + auto_conditioning = auto_conditioning.to(self.device) + + with torch.no_grad(): + calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" + if verbose: + print("Generating autoregressive samples..") + with torch.autocast( + device_type="cuda" , dtype=torch.float16, enabled=self.half + ): + fake_inputs = self.autoregressive.compute_embeddings( + auto_conditioning, + text_tokens, + ) + gpt_generator = self.autoregressive.get_generator( + fake_inputs=fake_inputs, + top_k=50, + top_p=top_p, + temperature=temperature, + do_sample=True, + num_beams=1, + num_return_sequences=1, + length_penalty=float(length_penalty), + repetition_penalty=float(repetition_penalty), + output_attentions=False, + output_hidden_states=True, + **hf_generate_kwargs, + ) + all_latents = [] + codes_ = [] + wav_gen_prev = None + wav_overlap = None + is_end = False + first_buffer = 60 + while not is_end: + try: + with torch.autocast( + device_type="cuda", dtype=torch.float16, enabled=self.half + ): + codes, latent = next(gpt_generator) + all_latents += [latent] + codes_ += [codes] + except StopIteration: + is_end = True + + if is_end or (stream_chunk_size > 0 and len(codes_) >= max(stream_chunk_size, first_buffer)): + first_buffer = 0 + gpt_latents = torch.cat(all_latents, dim=0)[None, :] + wav_gen = self.hifi_decoder.inference(gpt_latents.to(self.device), auto_conditioning) + wav_gen = wav_gen.squeeze() + wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( + wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len + ) + codes_ = [] + yield wav_chunk + def tts(self, text, voice_samples=None, k=1, verbose=True, 
use_deterministic_seed=None, + # autoregressive generation parameters follow + num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, + top_p=.8, max_mel_tokens=500, + # CVVP parameters follow + cvvp_amount=.0, + **hf_generate_kwargs): + """ + Produces an audio clip of the given text being spoken with the given reference voice. + :param text: Text to be spoken. + :param voice_samples: List of 2 or more ~10 second reference clips which should be torch tensors containing 22.05kHz waveform data. + :param conditioning_latents: A tuple of (autoregressive_conditioning_latent, diffusion_conditioning_latent), which + can be provided in lieu of voice_samples. This is ignored unless voice_samples=None. + Conditioning latents can be retrieved via get_conditioning_latents(). + :param k: The number of returned clips. The most likely (as determined by Tortoises' CLVP model) clips are returned. + :param verbose: Whether or not to print log messages indicating the progress of creating a clip. Default=true. + ~~AUTOREGRESSIVE KNOBS~~ + :param num_autoregressive_samples: Number of samples taken from the autoregressive model, all of which are filtered using CLVP. + As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great". + :param temperature: The softmax temperature of the autoregressive model. + :param length_penalty: A length penalty applied to the autoregressive decoder. Higher settings causes the model to produce more terse outputs. + :param repetition_penalty: A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce the incidence + of long silences or "uhhhhhhs", etc. + :param top_p: P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs. + :param max_mel_tokens: Restricts the output length. (0,600] integer. Each unit is 1/20 of a second. 
+ ~~DIFFUSION KNOBS~~ + :param diffusion_iterations: Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine + the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, + however. + :param cond_free: Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for + each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output + of the two is blended according to the cond_free_k value below. Conditioning-free diffusion is the real deal, and + dramatically improves realism. + :param cond_free_k: Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf]. + As cond_free_k increases, the output becomes dominated by the conditioning-free signal. + Formula is: output=cond_present_output*(cond_free_k+1)-cond_absenct_output*cond_free_k + :param diffusion_temperature: Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0 + are the "mean" prediction of the diffusion network and will sound bland and smeared. + ~~OTHER STUFF~~ + :param hf_generate_kwargs: The huggingface Transformers generate API is used for the autoregressive transformer. + Extra keyword args fed to this function get forwarded directly to that API. Documentation + here: https://huggingface.co/docs/transformers/internal/generation_utils + :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. + Sample rate is 24kHz. + """ + deterministic_seed = self.deterministic_state(seed=use_deterministic_seed) + + text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device) + text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. + assert text_tokens.shape[-1] < 400, 'Too much text provided. 
Break the text up into separate segments and re-try inference.' + if voice_samples is not None: + auto_conditioning = self.get_conditioning_latents(voice_samples, return_mels=False) + else: + auto_conditioning = self.get_random_conditioning_latents() + auto_conditioning = auto_conditioning.to(self.device) + + with torch.no_grad(): + calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" + if verbose: + print("Generating autoregressive samples..") + with torch.autocast( + device_type="cuda" , dtype=torch.float16, enabled=self.half + ): + codes = self.autoregressive.inference_speech(auto_conditioning, text_tokens, + top_k=50, + top_p=top_p, + temperature=temperature, + do_sample=True, + num_beams=1, + num_return_sequences=1, + length_penalty=float(length_penalty), + repetition_penalty=float(repetition_penalty), + output_attentions=False, + output_hidden_states=True, + **hf_generate_kwargs) + gpt_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1), + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes, + torch.tensor([codes.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device), + return_latent=True, clip_inputs=False) + if verbose: + print("generating audio..") + wav_gen = self.hifi_decoder.inference(gpt_latents.to(self.device), auto_conditioning) + return wav_gen + def deterministic_state(self, seed=None): + """ + Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be + reproduced. + """ + seed = int(time()) if seed is None else seed + torch.manual_seed(seed) + random.seed(seed) + # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary. 
+ # torch.use_deterministic_algorithms(True) + + return seed diff --git a/tortoise/data/got.txt b/tortoise/data/got.txt new file mode 100644 index 0000000000000000000000000000000000000000..a7180b9a224d40d824807fefe86d592b709216ad --- /dev/null +++ b/tortoise/data/got.txt @@ -0,0 +1,276 @@ +Chapter One + + +Bran + + +The morning had dawned clear and cold, with a crispness that hinted at the end of summer. They set forth at daybreak to see a man beheaded, twenty in all, and Bran rode among them, nervous with excitement. This was the first time he had been deemed old enough to go with his lord father and his brothers to see the king's justice done. It was the ninth year of summer, and the seventh of Bran's life. + + +The man had been taken outside a small holdfast in the hills. Robb thought he was a wildling, his sword sworn to Mance Rayder, the King-beyond-the-Wall. It made Bran's skin prickle to think of it. He remembered the hearth tales Old Nan told them. The wildlings were cruel men, she said, slavers and slayers and thieves. They consorted with giants and ghouls, stole girl children in the dead of night, and drank blood from polished horns. And their women lay with the Others in the Long Night to sire terrible half-human children. + + +But the man they found bound hand and foot to the holdfast wall awaiting the king's justice was old and scrawny, not much taller than Robb. He had lost both ears and a finger to frostbite, and he dressed all in black, the same as a brother of the Night's Watch, except that his furs were ragged and greasy. + + +The breath of man and horse mingled, steaming, in the cold morning air as his lord father had the man cut down from the wall and dragged before them. Robb and Jon sat tall and still on their horses, with Bran between them on his pony, trying to seem older than seven, trying to pretend that he'd seen all this before. A faint wind blew through the holdfast gate. 
Over their heads flapped the banner of the Starks of Winterfell: a grey direwolf racing across an ice-white field. + +Bran's father sat solemnly on his horse, long brown hair stirring in the wind. His closely trimmed beard was shot with white, making him look older than his thirty-five years. He had a grim cast to his grey eyes this day, and he seemed not at all the man who would sit before the fire in the evening and talk softly of the age of heroes and the children of the forest. He had taken off Father's face, Bran thought, and donned the face of Lord Stark of Winterfell. + + +There were questions asked and answers given there in the chill of morning, but afterward Bran could not recall much of what had been said. Finally his lord father gave a command, and two of his guardsmen dragged the ragged man to the ironwood stump in the center of the square. They forced his head down onto the hard black wood. Lord Eddard Stark dismounted and his ward Theon Greyjoy brought forth the sword. "Ice," that sword was called. It was as wide across as a man's hand, and taller even than Robb. The blade was Valyrian steel, spell-forged and dark as smoke. Nothing held an edge like Valyrian steel. + + +His father peeled off his gloves and handed them to Jory Cassel, the captain of his household guard. He took hold of Ice with both hands and said, "In the name of Robert of the House Baratheon, the First of his Name, King of the Andals and the Rhoynar and the First Men, Lord of the Seven Kingdoms and Protector of the Realm, by the word of Eddard of the House Stark, Lord of Winterfell and Warden of the North, I do sentence you to die." He lifted the greatsword high above his head. + + +Bran's bastard brother Jon Snow moved closer. "Keep the pony well in hand," he whispered. "And don't look away. Father will know if you do." + + +Bran kept his pony well in hand, and did not look away. + + +His father took off the man's head with a single sure stroke. 
Blood sprayed out across the snow, as red as surnmerwine. One of the horses reared and had to be restrained to keep from bolting. Bran could not take his eyes off the blood. The snows around the stump drank it eagerly, reddening as he watched. + +The head bounced off a thick root and rolled. It came up near Greyjoy's feet. Theon was a lean, dark youth of nineteen who found everything amusing. He laughed, put his boot on the head, and kicked it away. + + +"Ass," Jon muttered, low enough so Greyjoy did not hear. He put a hand on Bran's shoulder, and Bran looked over at his bastard brother. "You did well," Jon told him solemnly. Jon was fourteen, an old hand at justice. + + +It seemed colder on the long ride back to Winterfell, though the wind had died by then and the sun was higher in the sky. Bran rode with his brothers, well ahead of the main party, his pony struggling hard to keep up with their horses. + + +"The deserter died bravely," Robb said. He was big and broad and growing every day, with his mother's coloring, the fair skin, red-brown hair, and blue eyes of the Tullys of Riverrun. "He had courage, at the least." + + +"No," Jon Snow said quietly. "It was not courage. This one was dead of fear. You could see it in his eyes, Stark." Jon's eyes were a grey so dark they seemed almost black, but there was little they did not see. He was of an age with Robb, but they did not look alike. Jon was slender where Robb was muscular, dark where Robb was fair, graceful and quick where his half brother was strong and fast. + + +Robb was not impressed. "The Others take his eyes," he swore. "He died well. Race you to the bridge?" + + +"Done," Jon said, kicking his horse forward. Robb cursed and followed, and they galloped off down the trail, Robb laughing and hooting, Jon silent and intent. The hooves of their horses kicked up showers of snow as they went. + +Bran did not try to follow. His pony could not keep up. 
He had seen the ragged man's eyes, and he was thinking of them now. After a while, the sound of Robb's laughter receded, and the woods grew silent again. + + +So deep in thought was he that he never heard the rest of the party until his father moved up to ride beside him. "Are you well, Bran?" he asked, not unkindly. + + +"Yes, Father," Bran told him. He looked up. Wrapped in his furs and leathers, mounted on his great warhorse, his lord father loomed over him like a giant. "Robb says the man died bravely, but Jon says he was afraid." + + +"What do you think?" his father asked. + + +Bran thought about it. "Can a man still be brave if he's afraid?" + + +"That is the only time a man can be brave," his father told him. "Do you understand why I did it?" + + +"He was a wildling," Bran said. "They carry off women and sell them to the Others." + + +His lord father smiled. "Old Nan has been telling you stories again. In truth, the man was an oathbreaker, a deserter from the Night's Watch. No man is more dangerous. The deserter knows his life is forfeit if he is taken, so he will not flinch from any crime, no matter how vile. But you mistake me. The question was not why the man had to die, but why I must do it." + + +Bran had no answer for that. "King Robert has a headsman," he said, uncertainly. + + +"He does," his father admitted. "As did the Targaryen kings before him. Yet our way is the older way. The blood of the First Men still flows in the veins of the Starks, and we hold to the belief that the man who passes the sentence should swing the sword. If you would take a man's life, you owe it to him to look into his eyes and hear his final words. And if you cannot bear to do that, then perhaps the man does not deserve to die. + + +"One day, Bran, you will be Robb's bannerman, holding a keep of your own for your brother and your king, and justice will fall to you. When that day comes, you must take no pleasure in the task, but neither must you look away. 
A ruler who hides behind paid executioners soon forgets what death is." + + +That was when Jon reappeared on the crest of the hill before them. He waved and shouted down at them. "Father, Bran, come quickly, see what Robb has found!" Then he was gone again. + + +Jory rode up beside them. "Trouble, my lord?" + + +"Beyond a doubt," his lord father said. "Come, let us see what mischief my sons have rooted out now." He sent his horse into a trot. Jory and Bran and the rest came after. + + +They found Robb on the riverbank north of the bridge, with Jon still mounted beside him. The late summer snows had been heavy this moonturn. Robb stood knee-deep in white, his hood pulled back so the sun shone in his hair. He was cradling something in his arm, while the boys talked in hushed, excited voices. + + +The riders picked their way carefully through the drifts, groping for solid footing on the hidden, uneven ground . Jory Cassel and Theon Greyjoy were the first to reach the boys. Greyjoy was laughing and joking as he rode. Bran heard the breath go out of him. "Gods!" he exclaimed, struggling to keep control of his horse as he reached for his sword. + + +Jory's sword was already out. "Robb, get away from it!" he called as his horse reared under him. + + +Robb grinned and looked up from the bundle in his arms. "She can't hurt you," he said. "She's dead, Jory." + + +Bran was afire with curiosity by then. He would have spurred the pony faster, but his father made them dismount beside the bridge and approach on foot. Bran jumped off and ran. + + +By then Jon, Jory, and Theon Greyjoy had all dismounted as well. "What in the seven hells is it?" Greyjoy was saying. + + +"A wolf," Robb told him. + + +"A freak," Greyjoy said. "Look at the size of it." + + +Bran's heart was thumping in his chest as he pushed through a waist-high drift to his brothers' side. + + +Half-buried in bloodstained snow, a huge dark shape slumped in death. 
Ice had formed in its shaggy grey fur, and the faint smell of corruption clung to it like a woman's perfume. Bran glimpsed blind eyes crawling with maggots, a wide mouth full of yellowed teeth. But it was the size of it that made him gasp. It was bigger than his pony, twice the size of the largest hound in his father's kennel. + + +"It's no freak," Jon said calmly. "That's a direwolf. They grow larger than the other kind." + + +Theon Greyjoy said, "There's not been a direwolf sighted south of the Wall in two hundred years." + + +"I see one now," Jon replied. + + +Bran tore his eyes away from the monster. That was when he noticed the bundle in Robb's arms. He gave a cry of delight and moved closer. The pup was a tiny ball of grey-black fur, its eyes still closed. It nuzzled blindly against Robb's chest as he cradled it, searching for milk among his leathers, making a sad little whimpery sound. Bran reached out hesitantly. "Go on," Robb told him. "You can touch him." + + +Bran gave the pup a quick nervous stroke, then turned as Jon said, "Here you go." His half brother put a second pup into his arms. "There are five of them." Bran sat down in the snow and hugged the wolf pup to his face. Its fur was soft and warm against his cheek. + + +"Direwolves loose in the realm, after so many years," muttered Hullen, the master of horse. "I like it not." + + +"It is a sign," Jory said. + + +Father frowned. "This is only a dead animal, Jory," he said. Yet he seemed troubled. Snow crunched under his boots as he moved around the body. "Do we know what killed her?" + + +"There's something in the throat," Robb told him, proud to have found the answer before his father even asked. "There, just under the jaw." + + +His father knelt and groped under the beast's head with his hand. He gave a yank and held it up for all to see. A foot of shattered antler, tines snapped off, all wet with blood. + + +A sudden silence descended over the party. 
The men looked at the antler uneasily, and no one dared to speak. Even Bran could sense their fear, though he did not understand. + + +His father tossed the antler to the side and cleansed his hands in the snow. "I'm surprised she lived long enough to whelp," he said. His voice broke the spell. + + +"Maybe she didn't," Jory said. "I've heard tales . . . maybe the bitch was already dead when the pups came." + + +"Born with the dead," another man put in. "Worse luck." + + +"No matter," said Hullen. "They be dead soon enough too." + + +Bran gave a wordless cry of dismay. + + +"The sooner the better," Theon Greyjoy agreed. He drew his sword. "Give the beast here, Bran." + + +The little thing squirmed against him, as if it heard and understood. "No!" Bran cried out fiercely. "It's mine." + + +"Put away your sword, Greyjoy," Robb said. For a moment he sounded as commanding as their father, like the lord he would someday be. "We will keep these pups." + + +"You cannot do that, boy," said Harwin, who was Hullen's son. + + +"It be a mercy to kill them," Hullen said. + + +Bran looked to his lord father for rescue, but got only a frown, a furrowed brow. "Hullen speaks truly, son. Better a swift death than a hard one from cold and starvation." + + +"No!" He could feel tears welling in his eyes, and he looked away. He did not want to cry in front of his father. + + +Robb resisted stubbornly. "Ser Rodrik's red bitch whelped again last week," he said. "It was a small litter, only two live pups. She'll have milk enough." + + +"She'll rip them apart when they try to nurse." + + +"Lord Stark," Jon said. It was strange to hear him call Father that, so formal. Bran looked at him with desperate hope. "There are five pups," he told Father. "Three male, two female." + + +"What of it, Jon?" + + +"You have five trueborn children," Jon said. "Three sons, two daughters. The direwolf is the sigil of your House. Your children were meant to have these pups, my lord." 
+ + +Bran saw his father's face change, saw the other men exchange glances. He loved Jon with all his heart at that moment. Even at seven, Bran understood what his brother had done. The count had come right only because Jon had omitted himself. He had included the girls, included even Rickon, the baby, but not the bastard who bore the surname Snow, the name that custom decreed be given to all those in the north unlucky enough to be born with no name of their own. + + +Their father understood as well. "You want no pup for yourself, Jon?" he asked softly. + + +"The direwolf graces the banners of House Stark," Jon pointed out. "I am no Stark, Father." + + +Their lord father regarded Jon thoughtfully. Robb rushed into the silence he left. "I will nurse him myself, Father," he promised. "I will soak a towel with warm milk, and give him suck from that." + + +"Me too!" Bran echoed. + + +The lord weighed his sons long and carefully with his eyes. "Easy to say, and harder to do. I will not have you wasting the servants' time with this. If you want these pups, you will feed them yourselves. Is that understood?" + + +Bran nodded eagerly. The pup squirmed in his grasp, licked at his face with a warm tongue. + + +"You must train them as well," their father said. "You must train them. The kennelmaster will have nothing to do with these monsters, I promise you that. And the gods help you if you neglect them, or brutalize them, or train them badly. These are not dogs to beg for treats and slink off at a kick. A direwolf will rip a man's arm off his shoulder as easily as a dog will kill a rat. Are you sure you want this?" + +"Yes, Father," Bran said. + + +"Yes," Robb agreed. + + +"The pups may die anyway, despite all you do." + + +"They won't die," Robb said. "We won't let them die." + + +"Keep them, then. Jory, Desmond, gather up the other pups. It's time we were back to Winterfell." 
+ + +It was not until they were mounted and on their way that Bran allowed himself to taste the sweet air of victory. By then, his pup was snuggled inside his leathers, warm against him, safe for the long ride home. Bran was wondering what to name him. + + +Halfway across the bridge, Jon pulled up suddenly. + + +"What is it, Jon?" their lord father asked. + + +"Can't you hear it?" + + +Bran could hear the wind in the trees, the clatter of their hooves on the ironwood planks, the whimpering of his hungry pup, but Jon was listening to something else. + + +"There," Jon said. He swung his horse around and galloped back across the bridge. They watched him dismount where the direwolf lay dead in the snow, watched him kneel. A moment later he was riding back to them, smiling. + + +"He must have crawled away from the others," Jon said. + + +"Or been driven away," their father said, looking at the sixth pup. His fur was white, where the rest of the litter was grey. His eyes were as red as the blood of the ragged man who had died that morning. Bran thought it curious that this pup alone would have opened his eyes while the others were still blind. + + +"An albino," Theon Greyjoy said with wry amusement. "This one will die even faster than the others." + + +Jon Snow gave his father's ward a long, chilling look. "I think not, Greyjoy," he said. "This one belongs to me." 
\ No newline at end of file diff --git a/tortoise/data/layman.txt b/tortoise/data/layman.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tortoise/data/mel_norms.pth b/tortoise/data/mel_norms.pth new file mode 100644 index 0000000000000000000000000000000000000000..ed4d6e4f71fba223d920da25f1bbd0c8619433b5 --- /dev/null +++ b/tortoise/data/mel_norms.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f69422a8a8f344c4fca2f0c6b8d41d2151d6615b7321e48e6bb15ae949b119c +size 1067 diff --git a/tortoise/data/riding_hood.txt b/tortoise/data/riding_hood.txt new file mode 100644 index 0000000000000000000000000000000000000000..2987bef78f92ecb327fc0f754b7ab1211a18542b --- /dev/null +++ b/tortoise/data/riding_hood.txt @@ -0,0 +1,54 @@ +Once upon a time there lived in a certain village a little country girl, the prettiest creature who was ever seen. Her mother was excessively fond of her; and her grandmother doted on her still more. This good woman had a little red riding hood made for her. It suited the girl so extremely well that everybody called her Little Red Riding Hood. +One day her mother, having made some cakes, said to her, "Go, my dear, and see how your grandmother is doing, for I hear she has been very ill. Take her a cake, and this little pot of butter." + +Little Red Riding Hood set out immediately to go to her grandmother, who lived in another village. + +As she was going through the wood, she met with a wolf, who had a very great mind to eat her up, but he dared not, because of some woodcutters working nearby in the forest. He asked her where she was going. The poor child, who did not know that it was dangerous to stay and talk to a wolf, said to him, "I am going to see my grandmother and carry her a cake and a little pot of butter from my mother." + +"Does she live far off?" 
said the wolf. + +"Oh I say," answered Little Red Riding Hood; "it is beyond that mill you see there, at the first house in the village." + +"Well," said the wolf, "and I'll go and see her too. I'll go this way and go you that, and we shall see who will be there first." + +The wolf ran as fast as he could, taking the shortest path, and the little girl took a roundabout way, entertaining herself by gathering nuts, running after butterflies, and gathering bouquets of little flowers. It was not long before the wolf arrived at the old woman's house. He knocked at the door: tap, tap. + +"Who's there?" + +"Your grandchild, Little Red Riding Hood," replied the wolf, counterfeiting her voice; "who has brought you a cake and a little pot of butter sent you by mother." + +The good grandmother, who was in bed, because she was somewhat ill, cried out, "Pull the bobbin, and the latch will go up." + +The wolf pulled the bobbin, and the door opened, and then he immediately fell upon the good woman and ate her up in a moment, for it had been more than three days since he had eaten. He then shut the door and got into the grandmother's bed, expecting Little Red Riding Hood, who came some time afterwards and knocked at the door: tap, tap. + +"Who's there?" + +Little Red Riding Hood, hearing the big voice of the wolf, was at first afraid; but believing her grandmother had a cold and was hoarse, answered, "It is your grandchild Little Red Riding Hood, who has brought you a cake and a little pot of butter mother sends you." + +The wolf cried out to her, softening his voice as much as he could, "Pull the bobbin, and the latch will go up." + +Little Red Riding Hood pulled the bobbin, and the door opened. + +The wolf, seeing her come in, said to her, hiding himself under the bedclothes, "Put the cake and the little pot of butter upon the stool, and come get into bed with me." + +Little Red Riding Hood took off her clothes and got into bed. 
She was greatly amazed to see how her grandmother looked in her nightclothes, and said to her, "Grandmother, what big arms you have!" + +"All the better to hug you with, my dear." + +"Grandmother, what big legs you have!" + +"All the better to run with, my child." + +"Grandmother, what big ears you have!" + +"All the better to hear with, my child." + +"Grandmother, what big eyes you have!" + +"All the better to see with, my child." + +"Grandmother, what big teeth you have got!" + +"All the better to eat you up with." + +And, saying these words, this wicked wolf fell upon Little Red Riding Hood, and ate her all up. \ No newline at end of file diff --git a/tortoise/data/seal_copypasta.txt b/tortoise/data/seal_copypasta.txt new file mode 100644 index 0000000000000000000000000000000000000000..ce59a386070125650d3c6d8e8a13801d3666aa5f --- /dev/null +++ b/tortoise/data/seal_copypasta.txt @@ -0,0 +1 @@ +What the fuck did you just fucking say about me, you little bitch? I'll have you know I graduated top of my class in the Navy Seals, and I've been involved in numerous secret raids on Al kayda, and I have over 300 confirmed kills. I am trained in gorilla warfare and I'm the top sniper in the entire U S armed forces. You are nothing to me but just another target. I will wipe you the fuck out with precision the likes of which has never been seen before on this Earth, mark my fucking words. You think you can get away with saying that shit to me over the Internet? Think again, fucker. As we speak I am contacting my secret network of spies across the U S A and your IP is being traced right now so you better prepare for the storm, maggot. The storm that wipes out the pathetic little thing you call your life. You're fucking dead, kid. I can be anywhere, anytime, and I can kill you in over seven hundred ways, and that's just with my bare hands. 
Not only am I extensively trained in unarmed combat, but I have access to the entire arsenal of the United States Marine Corps and I will use it to its full extent to wipe your miserable ass off the face of the continent, you little shit. If only you could have known what unholy retribution your little "clever" comment was about to bring down upon you, maybe you would have held your fucking tongue. But you couldn't, you didn't, and now you're paying the price, you goddamn idiot. I will shit fury all over you and you will drown in it. You're fucking dead, kiddo. \ No newline at end of file diff --git a/tortoise/data/tokenizer.json b/tortoise/data/tokenizer.json new file mode 100644 index 0000000000000000000000000000000000000000..a128f273053e465a15c488e48d8106e0c8b0898e --- /dev/null +++ b/tortoise/data/tokenizer.json @@ -0,0 +1 @@ +{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68
,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c 
def parse_bool(value):
    """
    Convert a command-line token into a real boolean.

    ``argparse`` calls ``type(value)`` on the raw string, so ``type=bool`` turns every
    non-empty string (including ``"False"``) into ``True``. This parser accepts the usual
    spellings instead and rejects anything it does not recognise.

    :param value: the raw command-line string, or an already-parsed bool.
    :return: the corresponding boolean.
    :raises argparse.ArgumentTypeError: if the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if token in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}.')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--text', type=str, help='Text to speak.', default="The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them.")
    parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
                                                 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='random')
    parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='fast')
    # Boolean switches go through parse_bool so that "--flag False" really disables the flag.
    parser.add_argument('--use_deepspeed', type=parse_bool, help='Whether to use DeepSpeed for inference.', default=False)
    parser.add_argument('--kv_cache', type=parse_bool, help='If you disable this please wait for a long a time to get the output', default=True)
    parser.add_argument('--half', type=parse_bool, help="float16(half) precision inference if True it's faster and take less vram and ram", default=True)
    parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
    parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
                                                      'should only be specified if you have custom checkpoints.', default=MODELS_DIR)
    parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
    parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
    parser.add_argument('--produce_debug_state', type=parse_bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
    parser.add_argument('--cvvp_amount', type=float, help='How much the CVVP model should influence the output.'
                                                          'Increasing this can in some cases reduce the likelihood of multiple speakers. Defaults to 0 (disabled)', default=.0)
    args = parser.parse_args()
    # DeepSpeed is force-disabled whenever the MPS backend is available.
    if torch.backends.mps.is_available():
        args.use_deepspeed = False
    os.makedirs(args.output_path, exist_ok=True)
    tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half)

    # A comma separates independent runs; '&' joins several voices into one run.
    selected_voices = args.voice.split(',')
    for k, selected_voice in enumerate(selected_voices):
        # str.split returns a one-element list when there is no '&', so no branch is needed.
        voice_sel = selected_voice.split('&')
        voice_samples, conditioning_latents = load_voices(voice_sel)

        gen, dbg_state = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
                                             preset=args.preset, use_deterministic_seed=args.seed, return_deterministic_state=True, cvvp_amount=args.cvvp_amount)
        if isinstance(gen, list):
            # One output file per candidate.
            for j, g in enumerate(gen):
                torchaudio.save(os.path.join(args.output_path, f'{selected_voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000)
        else:
            torchaudio.save(os.path.join(args.output_path, f'{selected_voice}_{k}.wav'), gen.squeeze(0).cpu(), 24000)

        if args.produce_debug_state:
            os.makedirs('debug_states', exist_ok=True)
            torch.save(dbg_state, f'debug_states/do_tts_debug_{selected_voice}.pth')
preset.', default="standard") + args = parser.parse_args() + os.makedirs(args.output_path, exist_ok=True) + + tts = TextToSpeech() + + with open(args.eval_path, 'r', encoding='utf-8') as f: + lines = f.readlines() + + for line in lines: + text, real = line.strip().split('\t') + conds = [load_audio(real, 22050)] + gen = tts.tts_with_preset(text, voice_samples=conds, conditioning_latents=None, preset=args.preset) + torchaudio.save(os.path.join(args.output_path, os.path.basename(real)), gen.squeeze(0).cpu(), 24000) + diff --git a/tortoise/get_conditioning_latents.py b/tortoise/get_conditioning_latents.py new file mode 100644 index 0000000000000000000000000000000000000000..aa7e9b7dde64e4867cfdad025d739ca7fbff425f --- /dev/null +++ b/tortoise/get_conditioning_latents.py @@ -0,0 +1,30 @@ +import argparse +import os +import torch + +from api import TextToSpeech +from tortoise.utils.audio import load_audio, get_voices + +""" +Dumps the conditioning latents for the specified voice to disk. These are expressive latents which can be used for +other ML models, or can be augmented manually and fed back into Tortoise to affect vocal qualities. 
if __name__ == '__main__':
    # Command-line entry point: compute and save conditioning latents for each requested voice.
    cli = argparse.ArgumentParser()
    cli.add_argument('--voice', type=str, help='Selects the voice to convert to conditioning latents', default='pat2')
    cli.add_argument('--output_path', type=str, help='Where to store outputs.', default='../results/conditioning_latents')
    args = cli.parse_args()
    os.makedirs(args.output_path, exist_ok=True)

    tts = TextToSpeech()
    voices = get_voices()
    # Several voices may be requested at once, separated by commas.
    for voice in args.voice.split(','):
        # Load every reference clip registered for this voice at 22050 Hz.
        clips = [load_audio(clip_path, 22050) for clip_path in voices[voice]]
        latents = tts.get_conditioning_latents(clips)
        destination = os.path.join(args.output_path, f'{voice}.pth')
        torch.save(latents, destination)
0000000000000000000000000000000000000000..8874af3e0707aa7079e29a47973d01c9e0d8536a Binary files /dev/null and b/tortoise/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/tortoise/models/__pycache__/arch_util.cpython-310.pyc b/tortoise/models/__pycache__/arch_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c047a90e2775a9b56bb2020c3fa8e58a8233ff6 Binary files /dev/null and b/tortoise/models/__pycache__/arch_util.cpython-310.pyc differ diff --git a/tortoise/models/__pycache__/autoregressive.cpython-310.pyc b/tortoise/models/__pycache__/autoregressive.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33a03db89d8a8973d37a18db6e4ae20f6401d73a Binary files /dev/null and b/tortoise/models/__pycache__/autoregressive.cpython-310.pyc differ diff --git a/tortoise/models/__pycache__/classifier.cpython-310.pyc b/tortoise/models/__pycache__/classifier.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5b882404f9e37f67d9f9908c13778dd4f97273e Binary files /dev/null and b/tortoise/models/__pycache__/classifier.cpython-310.pyc differ diff --git a/tortoise/models/__pycache__/clvp.cpython-310.pyc b/tortoise/models/__pycache__/clvp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14dd043b1544787de2fa7dd13be8e03130a1a44e Binary files /dev/null and b/tortoise/models/__pycache__/clvp.cpython-310.pyc differ diff --git a/tortoise/models/__pycache__/cvvp.cpython-310.pyc b/tortoise/models/__pycache__/cvvp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f2846a663d041f66194c53640dae3aa2c024829 Binary files /dev/null and b/tortoise/models/__pycache__/cvvp.cpython-310.pyc differ diff --git a/tortoise/models/__pycache__/diffusion_decoder.cpython-310.pyc b/tortoise/models/__pycache__/diffusion_decoder.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3995a0f17e2b273bf87ec9d61fac3905b382a752 Binary files /dev/null and b/tortoise/models/__pycache__/diffusion_decoder.cpython-310.pyc differ diff --git a/tortoise/models/__pycache__/random_latent_generator.cpython-310.pyc b/tortoise/models/__pycache__/random_latent_generator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92ec75c42104a9f0337dd7d1130d460485065460 Binary files /dev/null and b/tortoise/models/__pycache__/random_latent_generator.cpython-310.pyc differ diff --git a/tortoise/models/__pycache__/transformer.cpython-310.pyc b/tortoise/models/__pycache__/transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8525abae4b316b0808b190faa5a17700877b3662 Binary files /dev/null and b/tortoise/models/__pycache__/transformer.cpython-310.pyc differ diff --git a/tortoise/models/__pycache__/vocoder.cpython-310.pyc b/tortoise/models/__pycache__/vocoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..706ba6e028bf5506dd3422a1ba46e37034e08b43 Binary files /dev/null and b/tortoise/models/__pycache__/vocoder.cpython-310.pyc differ diff --git a/tortoise/models/__pycache__/xtransformers.cpython-310.pyc b/tortoise/models/__pycache__/xtransformers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98fcf3178f1c0d1d64f0a4ac3711865beffe55d2 Binary files /dev/null and b/tortoise/models/__pycache__/xtransformers.cpython-310.pyc differ diff --git a/tortoise/models/arch_util.py b/tortoise/models/arch_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f678a0290cc16901b68bb46191a9f7df1001772a --- /dev/null +++ b/tortoise/models/arch_util.py @@ -0,0 +1,373 @@ +import os +import functools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from tortoise.models.xtransformers import ContinuousTransformerWrapper, 
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and return the same module.
    """
    for param in module.parameters():
        param.detach().zero_()
    return module


class GroupNorm32(nn.GroupNorm):
    """
    GroupNorm that runs on a float32 copy of the input and casts the result back
    to the input's original dtype.
    """

    def forward(self, x):
        original_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.type(original_dtype)


def normalization(channels):
    """
    Build the standard normalization layer for a given channel count.

    Starts from 8 groups (up to 16 channels), 16 groups (up to 64 channels) or
    32 groups (otherwise) and halves the group count until it divides ``channels``.

    :param channels: number of input channels.
    :return: a GroupNorm32 module for normalization.
    """
    if channels <= 16:
        groups = 8
    elif channels <= 64:
        groups = 16
    else:
        groups = 32
    # Halve until the channel count splits evenly into groups.
    while channels % groups:
        groups //= 2
    assert groups > 2
    return GroupNorm32(groups, channels)
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

    :param channels: number of input (and output) channels.
    :param num_heads: number of attention heads; used only when ``num_head_channels`` is -1.
    :param num_head_channels: channels per head; when not -1 the head count becomes
                              ``channels // num_head_channels`` and ``num_heads`` is ignored.
    :param do_checkpoint: stored on the instance; not otherwise used inside this block.
    :param relative_pos_embeddings: if True, a RelativePositionBias is passed to the attention.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        do_checkpoint=True,
        relative_pos_embeddings=False,
    ):
        super().__init__()
        self.channels = channels
        self.do_checkpoint = do_checkpoint
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.norm = normalization(channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        # split heads before split qkv
        self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
        if relative_pos_embeddings:
            # Use the effective head count (self.num_heads), not the raw ``num_heads`` argument:
            # the two differ when ``num_head_channels`` is given, and the attention reshapes its
            # weights by self.num_heads. Identical to the old behaviour when num_head_channels == -1.
            self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=self.num_heads, num_buckets=32, max_distance=64)
        else:
            self.relative_pos_embeddings = None

    def forward(self, x, mask=None):
        """
        Apply normalization, QKV attention and the output projection, with a residual connection.

        :param x: tensor whose first two dimensions are batch and channels; any further
                  dimensions are flattened into one for the attention and restored afterwards.
        :param mask: optional mask forwarded to the attention module.
        :return: tensor with the same shape as ``x``.
        """
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv, mask, self.relative_pos_embeddings)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
class Downsample(nn.Module):
    """
    A downsampling layer that uses either a strided convolution or average pooling.

    :param channels: channels in the inputs.
    :param use_conv: a bool determining if a convolution is applied; otherwise average
                     pooling is used and input/output channel counts must match.
    :param out_channels: channels in the outputs; defaults to ``channels``.
    :param factor: stride of the convolution, or kernel size and stride of the pooling.
    :param ksize: convolution kernel size (convolution only).
    :param pad: convolution padding (convolution only).
    """

    def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv

        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = nn.AvgPool1d(kernel_size=factor, stride=factor)
            return
        self.op = nn.Conv1d(
            self.channels, self.out_channels, ksize, stride=factor, padding=pad
        )

    def forward(self, x):
        n_in = x.shape[1]
        assert n_in == self.channels
        return self.op(x)
class AudioMiniEncoder(nn.Module):
    """
    Small convolutional encoder with attention on top.

    An input convolution is followed by ``depth`` stages, each made of ``resnet_blocks``
    ResBlocks and one convolutional Downsample that doubles the channel count. The result
    is projected to ``embedding_dim`` channels, passed through ``attn_blocks`` attention
    blocks, and the first position along the last axis is returned.
    """

    def __init__(self,
                 spec_dim,
                 embedding_dim,
                 base_channels=128,
                 depth=2,
                 resnet_blocks=2,
                 attn_blocks=4,
                 num_attn_heads=4,
                 dropout=0,
                 downsample_factor=2,
                 kernel_size=3):
        super().__init__()
        self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1))

        width = base_channels
        stages = []
        for _ in range(depth):
            stages += [ResBlock(width, dropout, kernel_size=kernel_size) for _ in range(resnet_blocks)]
            stages.append(Downsample(width, use_conv=True, out_channels=width * 2, factor=downsample_factor))
            width *= 2
        self.res = nn.Sequential(*stages)

        self.final = nn.Sequential(
            normalization(width),
            nn.SiLU(),
            nn.Conv1d(width, embedding_dim, 1),
        )

        attention_stack = [AttentionBlock(embedding_dim, num_attn_heads) for _ in range(attn_blocks)]
        self.attn = nn.Sequential(*attention_stack)
        self.dim = embedding_dim

    def forward(self, x):
        features = self.attn(self.final(self.res(self.init(x))))
        # Keep only index 0 of the last axis.
        return features[:, :, 0]
class TorchMelSpectrogram(nn.Module):
    """
    Computes a log-compressed MEL spectrogram with torchaudio and optionally divides it by stored norms.

    When `mel_norm_file` is not None it is read with `torch.load` at construction time and the loaded
    tensor is used as a divisor (after `unsqueeze(0).unsqueeze(-1)`) in `forward`.
    """
    def __init__(self, filter_length=1024, hop_length=256, win_length=1024, n_mel_channels=80, mel_fmin=0, mel_fmax=8000,
                 sampling_rate=22050, normalize=False, mel_norm_file=DEFAULT_MEL_NORM_FILE):
        super().__init__()
        # These are the default tacotron values for the MEL spectrogram.
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.sampling_rate = sampling_rate
        self.mel_stft = torchaudio.transforms.MelSpectrogram(
            n_fft=filter_length,
            hop_length=hop_length,
            win_length=win_length,
            power=2,
            normalized=normalize,
            sample_rate=sampling_rate,
            f_min=mel_fmin,
            f_max=mel_fmax,
            n_mels=n_mel_channels,
            norm="slaney",
        )
        self.mel_norm_file = mel_norm_file
        self.mel_norms = None if mel_norm_file is None else torch.load(mel_norm_file)

    def forward(self, inp):
        """Return the (optionally normalized) log-MEL spectrogram of a 2-D or 3-D audio tensor."""
        # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
        if inp.dim() == 3:
            inp = inp.squeeze(1)
        assert inp.dim() == 2
        if torch.backends.mps.is_available():
            # The transform is run on the CPU whenever an MPS backend is available.
            inp = inp.to('cpu')
        self.mel_stft = self.mel_stft.to(inp.device)
        # Perform dynamic range compression
        mel = torch.log(torch.clamp(self.mel_stft(inp), min=1e-5))
        if self.mel_norms is None:
            return mel
        self.mel_norms = self.mel_norms.to(mel.device)
        return mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
class CheckpointedXTransformerEncoder(nn.Module):
    """
    Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
    to channels-last that XTransformer expects.

    `needs_permute` / `exit_permute` switch the (0, 2, 1) permutation applied to the input / output.
    With `checkpoint=False` the transformer layers are left untouched.
    """
    def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs):
        super().__init__()
        self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
        self.needs_permute = needs_permute
        self.exit_permute = exit_permute

        if checkpoint:
            layers = self.transformer.attn_layers.layers
            # Iterate over a snapshot because entries are replaced while walking the list.
            for idx, (norm, block, residual) in enumerate(list(layers)):
                layers[idx] = nn.ModuleList([norm, CheckpointedLayer(block), residual])

    def forward(self, x, **kwargs):
        """Run the wrapped transformer, permuting axes 1 and 2 on the way in and/or out as configured."""
        if self.needs_permute:
            x = x.permute(0, 2, 1)
        out = self.transformer(x, **kwargs)
        return out.permute(0, 2, 1) if self.exit_permute else out
class ResBlock(nn.Module):
    """
    Basic residual convolutional block that uses GroupNorm.

    Two (Conv1d k=3 -> GroupNorm) stages with a ReLU between them; the input is added back to the
    result and a final ReLU is applied. Each GroupNorm uses `chan // 8` groups.
    """
    def __init__(self, chan):
        super().__init__()
        groups = chan // 8
        stages = [
            nn.Conv1d(chan, chan, kernel_size=3, padding=1),
            nn.GroupNorm(groups, chan),
            nn.ReLU(),
            nn.Conv1d(chan, chan, kernel_size=3, padding=1),
            nn.GroupNorm(groups, chan),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return relu(net(x) + x)."""
        branch = self.net(x)
        return F.relu(branch + x)
kwargs.get("token_type_ids", None) # usually None + if not self.kv_cache: + past_key_values = None + # only last token for inputs_ids if past is defined in kwargs + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + assert self.cached_mel_emb is not None + assert inputs_embeds is None # Not supported by this inference model. + assert labels is None # Training not supported by this inference model. 
+ return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # Create embedding + mel_len = self.cached_mel_emb.shape[1] + if input_ids.shape[1] != 1: + text_inputs = input_ids[:, mel_len:] + text_emb = self.embeddings(text_inputs) + text_emb = text_emb + self.text_pos_embedding(text_emb) + if self.cached_mel_emb.shape[0] != text_emb.shape[0]: + mel_emb = self.cached_mel_emb.repeat_interleave( + text_emb.shape[0] // self.cached_mel_emb.shape[0], 0 + ) + else: # this outcome only occurs once per loop in most cases + mel_emb = self.cached_mel_emb + emb = torch.cat([mel_emb, text_emb], dim=1) + else: + emb = self.embeddings(input_ids) + emb = emb + self.text_pos_embedding.get_fixed_embedding( + attention_mask.shape[1] - mel_len, attention_mask.device + ) + transformer_outputs = self.transformer( + inputs_embeds=emb, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + if torch.backends.mps.is_available(): + self.to(self.transformer.first_device) + else: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + transformer_outputs[1:] + + return CausalLMOutputWithCrossAttentions( + loss=None, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def 
class LearnedPositionEmbeddings(nn.Module):
    """
    Learned absolute position embeddings backed by an `nn.Embedding` table of `seq_len` rows.

    The table is initialized from a normal distribution with mean 0 and standard deviation `init`.
    """
    def __init__(self, seq_len, model_dim, init=.02):
        super().__init__()
        self.emb = nn.Embedding(seq_len, model_dim)
        # Initializing this way is standard for GPT-2
        nn.init.normal_(self.emb.weight, mean=0.0, std=init)

    def forward(self, x):
        """Return the embeddings for positions 0..x.shape[1]-1; only the shape and device of `x` are used."""
        positions = torch.arange(x.shape[1], device=x.device)
        return self.emb(positions)

    def get_fixed_embedding(self, ind, dev):
        """Return the embedding of the single position `ind` as a (1, 1, model_dim) tensor on `dev`."""
        index = torch.tensor([[ind]], device=dev)
        return self.emb(index)
class MelEncoder(nn.Module):
    """
    Convolutional encoder applied to MEL inputs with `mel_channels` channels.

    The stack widens channels//4 -> channels//2 -> channels, using a stride-2 convolution for each
    widening step and `resblocks_per_reduction` ResBlocks at every width. The result is permuted
    with (0, 2, 1) before being returned.
    """
    def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
        super().__init__()
        self.channels = channels
        quarter = channels // 4
        half = channels // 2

        def res_stack(width):
            # A run of ResBlocks operating at a fixed channel width.
            return nn.Sequential(*[ResBlock(width) for _ in range(resblocks_per_reduction)])

        self.encoder = nn.Sequential(
            nn.Conv1d(mel_channels, quarter, kernel_size=3, padding=1),
            res_stack(quarter),
            nn.Conv1d(quarter, half, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(channels // 16, half),
            nn.ReLU(),
            res_stack(half),
            nn.Conv1d(half, channels, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(channels // 8, channels),
            nn.ReLU(),
            res_stack(channels),
        )
        self.reduction = 4

    def forward(self, x):
        """Apply each encoder stage in order and return the output with axes 1 and 2 swapped."""
        hidden = x
        for stage in self.encoder:
            hidden = stage(hidden)
        return hidden.permute(0, 2, 1)
Must be divisible by model_dim. Recommend model_dim//64 + max_text_tokens: Maximum number of text tokens that will be encountered by model. + max_mel_tokens: Maximum number of MEL tokens that will be encountered by model. + max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s). + mel_length_compression: The factor between and . Used to compute MEL code padding given wav input length. + number_text_tokens: + start_text_token: + stop_text_token: + number_mel_codes: + start_mel_token: + stop_mel_token: + train_solo_embeddings: + use_mel_codes_as_input: + checkpointing: + """ + super().__init__() + + self.number_text_tokens = number_text_tokens + self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token + self.stop_text_token = 0 + self.number_mel_codes = number_mel_codes + self.start_mel_token = start_mel_token + self.stop_mel_token = stop_mel_token + self.layers = layers + self.heads = heads + self.max_mel_tokens = max_mel_tokens + self.max_text_tokens = max_text_tokens + self.model_dim = model_dim + self.max_conditioning_inputs = max_conditioning_inputs + self.mel_length_compression = mel_length_compression + self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) + self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim) + if use_mel_codes_as_input: + self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim) + else: + self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1) + self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ + build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing) + if train_solo_embeddings: + self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, 
requires_grad=True) + self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True) + else: + self.mel_solo_embedding = 0 + self.text_solo_embedding = 0 + + self.final_norm = nn.LayerNorm(model_dim) + self.text_head = nn.Linear(model_dim, self.number_text_tokens*types+1) + self.mel_head = nn.Linear(model_dim, self.number_mel_codes) + + # Initialize the embeddings per the GPT-2 scheme + embeddings = [self.text_embedding] + if use_mel_codes_as_input: + embeddings.append(self.mel_embedding) + for module in embeddings: + module.weight.data.normal_(mean=0.0, std=.02) + def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False): + seq_length = self.max_mel_tokens + self.max_text_tokens + 2 + gpt_config = GPT2Config( + vocab_size=self.max_mel_tokens, + n_positions=seq_length, + n_ctx=seq_length, + n_embd=self.model_dim, + n_layer=self.layers, + n_head=self.heads, + gradient_checkpointing=False, + use_cache=True, + ) + self.inference_model = GPT2InferenceModel( + gpt_config, + self.gpt, + self.mel_pos_embedding, + self.mel_embedding, + self.final_norm, + self.mel_head, + kv_cache=kv_cache, + ) + if use_deepspeed and half and torch.cuda.is_available(): + import deepspeed + self.ds_engine = deepspeed.init_inference(model=self.inference_model, + mp_size=1, + replace_with_kernel_inject=True, + dtype=torch.float16) + self.inference_model = self.ds_engine.module.eval() + elif use_deepspeed and torch.cuda.is_available(): + import deepspeed + self.ds_engine = deepspeed.init_inference(model=self.inference_model, + mp_size=1, + replace_with_kernel_inject=True, + dtype=torch.float32) + self.inference_model = self.ds_engine.module.eval() + else: + self.inference_model = self.inference_model.eval() + + # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head) + self.gpt.wte = self.mel_embedding + def build_aligned_inputs_and_targets(self, 
input, start_token, stop_token): + inp = F.pad(input, (1,0), value=start_token) + tar = F.pad(input, (0,1), value=stop_token) + return inp, tar + + def set_mel_padding(self, mel_input_tokens, wav_lengths): + """ + Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in + that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required + preformatting to create a working TTS model. + """ + # Set padding areas within MEL (currently it is coded with the MEL code for ). + mel_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc') + for b in range(len(mel_lengths)): + actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token. + if actual_end < mel_input_tokens.shape[-1]: + mel_input_tokens[b, actual_end:] = self.stop_mel_token + return mel_input_tokens + + def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False): + if second_inputs is not None: + emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) + else: + emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1) + + gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) + if get_attns: + return gpt_out.attentions + + enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input + enc = self.final_norm(enc) + + if return_latent: + return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1]+first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:] + + first_logits = enc[:, :first_inputs.shape[1]] + first_logits = first_head(first_logits) + first_logits = first_logits.permute(0,2,1) + if second_inputs is not None: + second_logits = enc[:, -second_inputs.shape[1]:] + 
second_logits = second_head(second_logits) + second_logits = second_logits.permute(0,2,1) + return first_logits, second_logits + else: + return first_logits + + def get_conditioning(self, speech_conditioning_input): + speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len( + speech_conditioning_input.shape) == 3 else speech_conditioning_input + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) + conds = conds.mean(dim=1) + return conds + + def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False, + return_latent=False, clip_inputs=True): + """ + Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode + (actuated by `text_first`). + + speech_conditioning_input: MEL float tensor, (b,1024) + text_inputs: long tensor, (b,t) + text_lengths: long tensor, (b,) + mel_inputs: long tensor, (b,m) + wav_lengths: long tensor, (b,) + raw_mels: MEL float tensor (b,80,s) + + If return_attentions is specified, only logits are returned. + If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned. + If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality. + """ + # Types are expressed by expanding the text embedding space. + if types is not None: + text_inputs = text_inputs * (1+types).unsqueeze(-1) + + if clip_inputs: + # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by + # chopping the inputs by the maximum actual length. 
+ max_text_len = text_lengths.max() + text_inputs = text_inputs[:, :max_text_len] + max_mel_len = wav_lengths.max() // self.mel_length_compression + mel_codes = mel_codes[:, :max_mel_len] + if raw_mels is not None: + raw_mels = raw_mels[:, :, :max_mel_len*4] + mel_codes = self.set_mel_padding(mel_codes, wav_lengths) + text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token) + mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token) + + conds = speech_conditioning_latent.unsqueeze(1) + text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) + if raw_mels is not None: + mel_inp = F.pad(raw_mels, (0, 8)) + else: + mel_inp = mel_codes + mel_emb = self.mel_embedding(mel_inp) + mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + + if text_first: + text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent) + if return_latent: + return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. + else: + mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent) + if return_latent: + return text_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. 
+ + if return_attentions: + return mel_logits + loss_text = F.cross_entropy(text_logits, text_targets.long()) + loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) + return loss_text.mean(), loss_mel.mean(), mel_logits + def compute_embeddings( + self, + cond_latents, + text_inputs, + ): + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token) + emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + conds = cond_latents.unsqueeze(1) + emb = torch.cat([conds, emb], dim=1) + self.inference_model.store_mel_emb(emb) + gpt_inputs = torch.full( + ( + emb.shape[0], + emb.shape[1] + 1, # +1 for the start_mel_token + ), + fill_value=1, + dtype=torch.long, + device=text_inputs.device, + ) + gpt_inputs[:, -1] = self.start_mel_token + return gpt_inputs + def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1, + max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs): + + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + + conds = speech_conditioning_latent.unsqueeze(1) + emb = torch.cat([conds, text_emb], dim=1) + self.inference_model.store_mel_emb(emb) + + fake_inputs = torch.full((emb.shape[0], conds.shape[1] + emb.shape[1],), fill_value=1, dtype=torch.long, + device=text_inputs.device) + fake_inputs[:, -1] = self.start_mel_token + trunc_index = fake_inputs.shape[1] + if input_tokens is None: + inputs = fake_inputs + else: + assert num_return_sequences % input_tokens.shape[0] == 0, "The number of return sequences must be divisible by the number of input sequences" + fake_inputs = fake_inputs.repeat(num_return_sequences, 1) + input_tokens = 
class ResBlock(nn.Module):
    """
    1-D residual block: (norm -> SiLU -> conv) followed by (norm -> SiLU -> dropout -> zero-initialized conv),
    added to a skip connection.

    :param channels: channels of the input.
    :param dropout: dropout probability used in the output layers.
    :param out_channels: channels of the output; falls back to `channels` when falsy.
    :param use_conv: when the channel count changes, use a `kernel_size` convolution for the skip
                     connection instead of a 1x1 convolution.
    :param use_scale_shift_norm: stored on the module; not otherwise used by this block.
    :param dims: kept only for signature compatibility. Every layer in this block is 1-D, so the value
                 is ignored.
    :param up: resample both the residual branch and the skip input with Upsample.
    :param down: resample both the residual branch and the skip input with Downsample.
    :param kernel_size: kernel size of the convolutions; padding is 1 for a kernel of 3 and 2 otherwise.
    :param do_checkpoint: stored on the module; not otherwise used by this block.
    """
    def __init__(
        self,
        channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
        kernel_size=3,
        do_checkpoint=True,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.do_checkpoint = do_checkpoint
        padding = 1 if kernel_size == 3 else 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down

        # `dims` must not be forwarded here: the third positional parameter of Downsample (and of the
        # up-sampling counterpart) is `out_channels`, so passing `dims` silently requested a 2-channel
        # output and tripped the channels == out_channels assertion of the non-convolutional path.
        if up:
            self.h_upd = Upsample(channels, False)
            self.x_upd = Upsample(channels, False)
        elif down:
            self.h_upd = Downsample(channels, False)
            self.x_upd = Downsample(channels, False)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
            ),
        )

        # nn.Conv1d takes (in_channels, out_channels, kernel_size, ...). The previous code passed `dims`
        # first, which shifted every argument by one and built a convolution with the wrong shape.
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = nn.Conv1d(
                channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)

    def forward(self, x):
        """Return skip_connection(x) + residual branch, resampling both first when up/down is enabled."""
        if self.updown:
            # Resample after norm + activation but before the first convolution.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        h = self.out_layers(h)
        return self.skip_connection(x) + h
res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size)) + res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor)) + ch *= 2 + self.res = nn.Sequential(*res) + self.final = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.Conv1d(ch, embedding_dim, 1) + ) + attn = [] + for a in range(attn_blocks): + attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False)) + self.attn = nn.Sequential(*attn) + self.dim = embedding_dim + + def forward(self, x): + h = self.init(x) + h = self.res(h) + h = self.final(h) + for blk in self.attn: + h = blk(h) + return h[:, :, 0] + + +class AudioMiniEncoderWithClassifierHead(nn.Module): + def __init__(self, classes, distribute_zero_label=True, **kwargs): + super().__init__() + self.enc = AudioMiniEncoder(**kwargs) + self.head = nn.Linear(self.enc.dim, classes) + self.num_classes = classes + self.distribute_zero_label = distribute_zero_label + + def forward(self, x, labels=None): + h = self.enc(x) + logits = self.head(h) + if labels is None: + return logits + else: + if self.distribute_zero_label: + oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes) + zeros_indices = (labels == 0).unsqueeze(-1) + # Distribute 20% of the probability mass on all classes when zero is specified, to compensate for dataset noise. 
def masked_mean(t, mask, dim = 1):
    """
    Mean of `t` along `dim`, counting only the positions where `mask` is True.

    :param t: tensor with one more trailing axis than `mask` (e.g. (batch, seq, features)).
    :param mask: boolean tensor matching the leading axes of `t` (e.g. (batch, seq)).
    :param dim: non-negative axis (shared by `t` and `mask`) to average over. Defaults to 1.
    :return: `t` reduced along `dim`. Slices with no valid position yield zeros instead of NaN.
    """
    t = t.masked_fill(~mask.unsqueeze(-1), 0.)
    # Previously the reduction axis was hard-coded to 1 and `dim` was ignored.
    # Clamp the count so a fully masked slice divides by 1 (giving 0) rather than by 0 (giving NaN).
    counts = mask.sum(dim = dim).clamp(min = 1)
    return t.sum(dim = dim) / counts.unsqueeze(-1)
+ + Originally from https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py + """ + + def __init__( + self, + *, + dim_text=512, + dim_speech=512, + dim_latent=512, + num_text_tokens=256, + text_enc_depth=6, + text_seq_len=120, + text_heads=8, + num_speech_tokens=8192, + speech_enc_depth=6, + speech_heads=8, + speech_seq_len=250, + text_mask_percentage=0, + voice_mask_percentage=0, + wav_token_compression=1024, + use_xformers=False, + ): + super().__init__() + self.text_emb = nn.Embedding(num_text_tokens, dim_text) + self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False) + + self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech) + self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False) + + if use_xformers: + self.text_transformer = CheckpointedXTransformerEncoder( + needs_permute=False, + exit_permute=False, + max_seq_len=-1, + attn_layers=Encoder( + dim=dim_text, + depth=text_enc_depth, + heads=text_heads, + ff_dropout=.1, + ff_mult=2, + attn_dropout=.1, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + )) + self.speech_transformer = CheckpointedXTransformerEncoder( + needs_permute=False, + exit_permute=False, + max_seq_len=-1, + attn_layers=Encoder( + dim=dim_speech, + depth=speech_enc_depth, + heads=speech_heads, + ff_dropout=.1, + ff_mult=2, + attn_dropout=.1, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + )) + else: + self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, + heads=text_heads) + self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech, + depth=speech_enc_depth, heads=speech_heads) + + self.temperature = nn.Parameter(torch.tensor(1.)) + self.text_mask_percentage = text_mask_percentage + self.voice_mask_percentage = voice_mask_percentage + self.wav_token_compression = wav_token_compression + self.xformers = use_xformers + if not use_xformers: + self.text_pos_emb = 
nn.Embedding(text_seq_len, dim_text) + self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech) + + def forward( + self, + text, + speech_tokens, + return_loss=False + ): + b, device = text.shape[0], text.device + if self.training: + text_mask = torch.rand_like(text.float()) > self.text_mask_percentage + voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage + else: + text_mask = torch.ones_like(text.float()).bool() + voice_mask = torch.ones_like(speech_tokens.float()).bool() + + text_emb = self.text_emb(text) + speech_emb = self.speech_emb(speech_tokens) + + if not self.xformers: + text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device)) + speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device)) + + enc_text = self.text_transformer(text_emb, mask=text_mask) + enc_speech = self.speech_transformer(speech_emb, mask=voice_mask) + + text_latents = masked_mean(enc_text, text_mask, dim=1) + speech_latents = masked_mean(enc_speech, voice_mask, dim=1) + + text_latents = self.to_text_latent(text_latents) + speech_latents = self.to_speech_latent(speech_latents) + + text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents)) + + temp = self.temperature.exp() + + if not return_loss: + sim = einsum('n d, n d -> n', text_latents, speech_latents) * temp + return sim + + sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp + labels = torch.arange(b, device=device) + loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2 + return loss + + +if __name__ == '__main__': + clip = CLVP(text_mask_percentage=.2, voice_mask_percentage=.2) + clip(torch.randint(0,256,(2,120)), + torch.tensor([50,100]), + torch.randint(0,8192,(2,250)), + torch.tensor([101,102]), + return_loss=True) + nonloss = clip(torch.randint(0,256,(2,120)), + torch.tensor([50,100]), + torch.randint(0,8192,(2,250)), + torch.tensor([101,102]), + return_loss=False) 
class CVVP(nn.Module):
    """
    Contrastive model that scores how well a conditioning mel clip matches a speech input.

    Both inputs are embedded, collapsed to a single vector per batch element by a `CollapsingTransformer`,
    projected by a bias-free linear layer and L2-normalized. Similarities are dot products scaled by the
    exponential of a learned temperature.
    """

    def __init__(
            self,
            model_dim=512,
            transformer_heads=8,
            dropout=.1,
            conditioning_enc_depth=8,
            cond_mask_percentage=0,
            mel_channels=80,
            mel_codes=None,
            speech_enc_depth=8,
            speech_mask_percentage=0,
            latent_multiplier=1,
    ):
        """
        :param model_dim: channel width of both transformers.
        :param transformer_heads: attention heads used by both transformers.
        :param dropout: dropout rate forwarded to both transformers.
        :param conditioning_enc_depth: depth of the conditioning transformer.
        :param cond_mask_percentage: mask percentage forwarded to the conditioning transformer.
        :param mel_channels: number of channels in the mel inputs.
        :param mel_codes: when given, the speech input is embedded as discrete codes from a vocabulary of this
                          size instead of being convolved as a mel spectrogram.
        :param speech_enc_depth: depth of the speech transformer.
        :param speech_mask_percentage: mask percentage forwarded to the speech transformer.
        :param latent_multiplier: the shared latent width is `latent_multiplier * model_dim`.
        """
        super().__init__()
        latent_dim = latent_multiplier*model_dim
        self.temperature = nn.Parameter(torch.tensor(1.))

        self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2),
                                      nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
        self.conditioning_transformer = CollapsingTransformer(
            model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage)
        # The conditioning transformer emits `model_dim` features, so the projection must accept `model_dim`
        # inputs. Previously this was Linear(latent_dim, latent_dim), which only worked when
        # latent_multiplier == 1 (where the two widths coincide, so existing checkpoints are unaffected).
        self.to_conditioning_latent = nn.Linear(
            model_dim, latent_dim, bias=False)

        if mel_codes is None:
            self.speech_emb = nn.Conv1d(
                mel_channels, model_dim, kernel_size=5, padding=2)
        else:
            self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
        self.speech_transformer = CollapsingTransformer(
            model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
        self.to_speech_latent = nn.Linear(
            latent_dim, latent_dim, bias=False)

    def get_grad_norm_parameter_groups(self):
        """Return the parameters of each transformer branch, keyed by a group name."""
        return {
            'conditioning': list(self.conditioning_transformer.parameters()),
            'speech': list(self.speech_transformer.parameters()),
        }

    def forward(
            self,
            mel_cond,
            mel_input,
            return_loss=False
    ):
        """
        Score conditioning clips against speech inputs.

        :param mel_cond: conditioning input fed to a Conv1d stack, i.e. laid out as [batch, mel_channels, time].
        :param mel_input: speech input; a [batch, mel_channels, time] tensor when `mel_codes` was None,
                          otherwise integer codes for the embedding.
        :param return_loss: when False, return one similarity per (mel_cond[i], mel_input[i]) pair; when True,
                            return the symmetric cross-entropy loss over the batch-by-batch similarity matrix.
        :return: a 1-D tensor of pairwise similarities, or a scalar loss.
        """
        # Embeddings come out channels-first; the transformers take channels-last.
        cond_emb = self.cond_emb(mel_cond).permute(0, 2, 1)
        enc_cond = self.conditioning_transformer(cond_emb)
        cond_latents = self.to_conditioning_latent(enc_cond)

        speech_emb = self.speech_emb(mel_input).permute(0, 2, 1)
        enc_speech = self.speech_transformer(speech_emb)
        speech_latents = self.to_speech_latent(enc_speech)

        cond_latents = F.normalize(cond_latents, p=2, dim=-1)
        speech_latents = F.normalize(speech_latents, p=2, dim=-1)
        temp = self.temperature.exp()

        if not return_loss:
            sim = einsum('n d, n d -> n', cond_latents,
                         speech_latents) * temp
            return sim

        # Full similarity matrix; the matching pair for row i is column i.
        sim = einsum('i d, j d -> i j', cond_latents,
                     speech_latents) * temp
        labels = torch.arange(
            cond_latents.shape[0], device=mel_input.device)
        loss = (F.cross_entropy(sim, labels) +
                F.cross_entropy(sim.t(), labels)) / 2

        return loss
class ResBlock(TimestepBlock):
    """
    1-D residual block whose hidden activations are modulated by a timestep embedding.

    The input is normalized, activated and convolved; the embedding is then either added to the result or,
    with `use_scale_shift_norm`, split into a scale and a shift applied after the output normalization. The
    block output is the skip path plus the processed path.
    """

    def __init__(
            self,
            channels,
            emb_channels,
            dropout,
            out_channels=None,
            dims=2,
            kernel_size=3,
            efficient_config=True,
            use_scale_shift_norm=False,
    ):
        """
        :param channels: number of input channels.
        :param emb_channels: width of the timestep embedding.
        :param dropout: dropout probability applied before the output convolution.
        :param out_channels: number of output channels; falls back to `channels` when falsy.
        :param dims: accepted for interface compatibility; not read by this implementation.
        :param kernel_size: kernel of the output convolution; must be 1, 3 or 5.
        :param efficient_config: use 1-wide kernels for the input and skip convolutions instead of 3-wide.
        :param use_scale_shift_norm: modulate with scale/shift after normalization instead of adding the embedding.
        """
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_scale_shift_norm = use_scale_shift_norm

        # "Same" padding for the supported output kernels; any other size raises KeyError.
        padding = {1: 0, 3: 1, 5: 2}[kernel_size]
        eff_kernel, eff_padding = (1, 0) if efficient_config else (3, 1)

        # Scale-shift modulation needs two values per output channel.
        emb_out_features = 2 * self.out_channels if use_scale_shift_norm else self.out_channels

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding),
        )
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_channels, emb_out_features),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding),
        )

        if self.out_channels != channels:
            # Channel count changes, so the skip path needs its own projection.
            self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x, emb):
        """
        Apply the block to `x` conditioned on the timestep embedding `emb`.

        :param x: input activations, channels on dim 1.
        :param emb: timestep embedding with `emb_channels` features.
        :return: skip path plus the embedding-modulated path.
        """
        hidden = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(hidden.dtype)

        # Append trailing singleton dims so the embedding broadcasts over the remaining axes.
        missing_dims = len(hidden.shape) - len(emb_out.shape)
        if missing_dims > 0:
            emb_out = emb_out[(...,) + (None,) * missing_dims]

        if not self.use_scale_shift_norm:
            hidden = self.out_layers(hidden + emb_out)
            return self.skip_connection(x) + hidden

        norm_layer = self.out_layers[0]
        remaining_layers = self.out_layers[1:]
        scale, shift = torch.chunk(emb_out, 2, dim=1)
        hidden = norm_layer(hidden) * (1 + scale) + shift
        hidden = remaining_layers(hidden)
        return self.skip_connection(x) + hidden
class DiffusionTts(nn.Module):
    """
    Diffusion decoder that denoises a channels-first input conditioned on aligned codes or latents, a
    conditioning latent and the diffusion timestep.

    The aligned conditioning is turned into an embedding once per sample (`timestep_independent`), mixed with
    the timestep embedding, concatenated with the projected input and run through a stack of `DiffusionLayer`s
    followed by three `ResBlock`s.
    """

    def __init__(
            self,
            model_channels=512,
            num_layers=8,
            in_channels=100,
            in_latent_channels=512,
            in_tokens=8193,
            out_channels=200,  # mean and variance
            dropout=0,
            use_fp16=False,
            num_heads=16,
            # Parameters for regularization.
            layer_drop=.1,
            unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
    ):
        """
        :param model_channels: internal channel width.
        :param num_layers: number of `DiffusionLayer`s in the main stack (three `ResBlock`s are appended).
        :param in_channels: channels of the input being denoised and of the conditioning clips.
        :param in_latent_channels: channels of latent aligned conditioning.
        :param in_tokens: vocabulary size for token aligned conditioning.
        :param out_channels: channels of the output.
        :param dropout: dropout forwarded to the diffusion layers and res blocks.
        :param use_fp16: enable autocast for all but the first layer of the main stack.
        :param num_heads: attention heads for every attention block.
        :param layer_drop: probability of skipping an interior layer while training.
        :param unconditioned_percentage: probability of replacing a batch element's conditioning with the
                                         learned unconditioned embedding while training.
        """
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.dropout = dropout
        self.num_heads = num_heads
        self.unconditioned_percentage = unconditioned_percentage
        self.enable_fp16 = use_fp16
        self.layer_drop = layer_drop

        self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1)
        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, model_channels),
            nn.SiLU(),
            nn.Linear(model_channels, model_channels),
        )

        # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
        # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
        # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
        # transformer network.
        self.code_embedding = nn.Embedding(in_tokens, model_channels)
        self.code_converter = nn.Sequential(
            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
        )
        self.code_norm = normalization(model_channels)
        self.latent_conditioner = nn.Sequential(
            nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
        )
        self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
                                                 nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False))
        self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
        self.conditioning_timestep_integrator = TimestepEmbedSequential(
            DiffusionLayer(model_channels, dropout, num_heads),
            DiffusionLayer(model_channels, dropout, num_heads),
            DiffusionLayer(model_channels, dropout, num_heads),
        )

        self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1)
        self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)

        self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
                                    [ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])

        self.out = nn.Sequential(
            normalization(model_channels),
            nn.SiLU(),
            nn.Conv1d(model_channels, out_channels, 3, padding=1),
        )

    def get_grad_norm_parameter_groups(self):
        """
        Group this model's parameters by sub-network for per-group gradient-norm reporting.

        :return: dict mapping a group name to a list of parameters; each parameter appears once per group.
        """
        groups = {
            'minicoder': list(self.contextual_embedder.parameters()),
            'layers': list(self.layers.parameters()),
            # latent_conditioner used to be listed twice here, which double-counted its parameters.
            'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()),
            'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
            'time_embed': list(self.time_embed.parameters()),
        }
        return groups

    def get_conditioning(self, conditioning_input):
        """
        Compute a conditioning latent from one or more conditioning clips.

        :param conditioning_input: a 3-D tensor (one clip per batch element) or a 4-D tensor whose dim 1 indexes
                                   multiple clips per batch element.
        :return: the contextual embeddings of all clips, concatenated along the last axis and averaged over it.
        """
        # Normalize to the multi-clip layout so the loop below handles both cases.
        speech_conditioning_input = conditioning_input.unsqueeze(1) if len(
            conditioning_input.shape) == 3 else conditioning_input
        conds = []
        for j in range(speech_conditioning_input.shape[1]):
            conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
        conds = torch.cat(conds, dim=-1)
        conds = conds.mean(dim=-1)
        return conds

    def timestep_independent(self, aligned_conditioning, conditioning_latent, expected_seq_len, return_code_pred):
        """
        Compute the part of the conditioning that does not depend on the diffusion timestep.

        :param aligned_conditioning: float latents (channels last) or long token ids aligned with the output.
        :param conditioning_latent: conditioning latent; split in half along dim 1 into a scale and a shift.
        :param expected_seq_len: length the code embedding is interpolated to.
        :param return_code_pred: also return the `mel_head` prediction made from the expanded embedding.
        :return: the expanded code embedding, or `(expanded_code_emb, mel_pred)` when `return_code_pred` is set.
        """
        # Shuffle aligned_latent to BxCxS format
        if is_latent(aligned_conditioning):
            aligned_conditioning = aligned_conditioning.permute(0, 2, 1)

        cond_scale, cond_shift = torch.chunk(conditioning_latent, 2, dim=1)
        if is_latent(aligned_conditioning):
            code_emb = self.latent_conditioner(aligned_conditioning)
        else:
            code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
            code_emb = self.code_converter(code_emb)
        code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)

        unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
        # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
        if self.training and self.unconditioned_percentage > 0:
            unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
                                               device=code_emb.device) < self.unconditioned_percentage
            code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
                                   code_emb)
        expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode='nearest')

        if not return_code_pred:
            return expanded_code_emb
        else:
            mel_pred = self.mel_head(expanded_code_emb)
            # Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
            mel_pred = mel_pred * unconditioned_batches.logical_not()
            return expanded_code_emb, mel_pred

    def forward(self, x, timesteps, aligned_conditioning=None, conditioning_latent=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
        :param conditioning_latent: a pre-computed conditioning latent; see get_conditioning().
        :param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
        :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
        :param return_code_pred: When set, also return the mel prediction from self.timestep_independent(). Cannot be
                                 combined with precomputed_aligned_embeddings.
        :return: an [N x C x ...] Tensor of outputs, or `(out, mel_pred)` when return_code_pred is set.
        """
        assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_latent is not None)
        assert not (return_code_pred and precomputed_aligned_embeddings is not None)  # These two are mutually exclusive.

        # NOTE(review): mel_pred is only assigned on the timestep_independent() path below, so
        # return_code_pred=True together with conditioning_free=True fails at the final return.
        unused_params = []
        if conditioning_free:
            code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
            unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
            unused_params.extend(list(self.latent_conditioner.parameters()))
        else:
            if precomputed_aligned_embeddings is not None:
                code_emb = precomputed_aligned_embeddings
            else:
                code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_latent, x.shape[-1], True)
                # Only one of the two conditioning encoders ran; the other is unused for this batch.
                if is_latent(aligned_conditioning):
                    unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
                else:
                    unused_params.extend(list(self.latent_conditioner.parameters()))

            unused_params.append(self.unconditioned_embedding)

        time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
        x = self.inp_block(x)
        x = torch.cat([x, code_emb], dim=1)
        x = self.integrating_conv(x)
        for i, lyr in enumerate(self.layers):
            # Do layer drop where applicable. Do not drop first and last layers.
            if self.training and self.layer_drop > 0 and i != 0 and i != (len(self.layers)-1) and random.random() < self.layer_drop:
                unused_params.extend(list(lyr.parameters()))
            else:
                # First and last blocks will have autocast disabled for improved precision.
                if not torch.backends.mps.is_available():
                    with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
                        x = lyr(x, time_emb)
                else:
                    x = lyr(x, time_emb)

        x = x.float()
        out = self.out(x)

        # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
        extraneous_addition = 0
        for p in unused_params:
            extraneous_addition = extraneous_addition + p.mean()
        out = out + extraneous_addition * 0

        if return_code_pred:
            return out, mel_pred
        return out
class ResBlock2(torch.nn.Module):
    """Residual Block Type 2: two weight-normalized dilated convolutions, each wrapped in its own residual step.

    Network::

        x -> lrelu -> conv1 -> (+ x) -> lrelu -> conv2 -> (+ previous) -> o

    Args:
        channels (int): number of input and output channels of every convolution.
        kernel_size (int): size of the convolution filter in each layer.
        dilation (tuple): dilation of the first and of the second convolution; only the first two entries are read.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()
        first_dilation, second_dilation = dilation[0], dilation[1]
        self.convs = nn.ModuleList(
            weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=rate,
                    # Padding chosen by get_padding for this kernel/dilation pair.
                    padding=get_padding(kernel_size, rate),
                )
            )
            for rate in (first_dilation, second_dilation)
        )

    def forward(self, x):
        """Run every convolution as `x = conv(leaky_relu(x)) + x` and return the result."""
        for conv in self.convs:
            x = conv(F.leaky_relu(x, LRELU_SLOPE)) + x
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparameterization from every convolution of this block."""
        for conv in self.convs:
            remove_weight_norm(conv)
+ inference_padding (int): constant padding applied to the input at inference time. Defaults to 5. + """ + super().__init__() + self.inference_padding = inference_padding + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_factors) + # initial upsampling layers + self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if resblock_type == "1" else ResBlock2 + # upsampling layers + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + # MRF blocks + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + # post convolution layer + self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)) + if cond_channels > 0: + self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1) + + if not conv_pre_weight_norm: + remove_weight_norm(self.conv_pre) + + if not conv_post_weight_norm: + remove_weight_norm(self.conv_post) + + self.device = torch.device('cuda' if torch.cuda.is_available() else'cpu') + if torch.backends.mps.is_available(): + self.device = torch.device('mps') + + def forward(self, x, g=None): + """ + Args: + x (Tensor): feature input tensor. + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. 
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    """
    Add an optional per-channel bias, apply leaky ReLU and multiply by `scale`.

    :param input: tensor whose dim 1 is the channel axis.
    :param bias: optional 1-D bias with one entry per channel; broadcast over every other axis.
    :param negative_slope: slope of the leaky ReLU for negative inputs.
    :param scale: factor applied to the activated output.
    :return: `leaky_relu(input + bias) * scale`.
    """
    if bias is not None:
        # Reshape the bias to (1, C, 1, ..., 1) so it lines up with the channel axis of `input`.
        rest_dim = [1] * (input.ndim - bias.ndim - 1)
        input = input + bias.view(1, bias.shape[0], *rest_dim)
    # Honor `negative_slope` on both paths; the bias-less path used to hard-code 0.2.
    return F.leaky_relu(input, negative_slope=negative_slope) * scale


class EqualLinear(nn.Module):
    """
    Linear layer with a runtime weight scale and a learning-rate multiplier, followed by a fused leaky ReLU.

    The weight is stored divided by `lr_mul` and multiplied by `lr_mul / sqrt(in_dim)` on every forward pass;
    the bias, when present, is multiplied by `lr_mul`.
    """

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1
    ):
        """
        :param in_dim: number of input features.
        :param out_dim: number of output features.
        :param bias: whether to learn a bias.
        :param bias_init: initial value of every bias entry.
        :param lr_mul: learning-rate multiplier applied to weight and bias at runtime.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None
        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        """Apply the scaled linear map, then the bias and fused leaky ReLU."""
        out = F.linear(input, self.weight * self.scale)
        # A layer built with bias=False stores None; scaling it would raise a TypeError.
        bias = None if self.bias is None else self.bias * self.lr_mul
        return fused_leaky_relu(out, bias)
def setup_seed(seed):
    """
    Seed every random number generator used during generation.

    Seeds torch (and all CUDA devices when CUDA is available), NumPy and Python's `random`, and switches
    cuDNN to deterministic mode. A seed of -1 means "leave the generators alone" and returns immediately.
    """
    seeding_disabled = seed == -1
    if not seeding_disabled:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.deterministic = True
+ generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). 
+ synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + kwargs: + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchDecoderOnlyOutput`], + - [`~generation.SampleDecoderOnlyOutput`], + - [`~generation.BeamSearchDecoderOnlyOutput`], + - [`~generation.BeamSampleDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchEncoderDecoderOutput`], + - [`~generation.SampleEncoderDecoderOutput`], + - [`~generation.BeamSearchEncoderDecoderOutput`], + - [`~generation.BeamSampleEncoderDecoderOutput`] + """ + setup_seed(seed) + # 1. 
Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + self._validate_model_class() + + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + if generation_config is None: + # legacy: users may modify the model configuration to control generation -- update the generation config + # model attribute accordingly, if it was created from the model config + if self.generation_config._from_model_config: + new_generation_config = StreamGenerationConfig.from_model_config( + self.config + ) + if new_generation_config != self.generation_config: + warnings.warn( + "You have modified the pretrained model configuration to control generation. This is a" + " deprecated strategy to control generation and will be removed soon, in a future version." + " Please use a generation configuration file (see" + " https://huggingface.co/docs/transformers/main_classes/text_generation)" + ) + self.generation_config = new_generation_config + generation_config = self.generation_config + + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update( + **kwargs + ) # All unused kwargs must be model kwargs + # self._validate_model_kwargs(model_kwargs.copy()) + + # 2. Set generation parameters if not already defined + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + + if ( + generation_config.pad_token_id is None + and generation_config.eos_token_id is not None + ): + if model_kwargs.get("attention_mask", None) is None: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." 
+ ) + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + logger.warning( + f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation." + ) + generation_config.pad_token_id = eos_token_id + + # 3. Define model inputs + # inputs_tensor has to be defined + # model_input_name is defined if model-specific keyword input is passed + # otherwise model_input_name is None + # all model-specific keyword inputs are removed from `model_kwargs` + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + + # 4. Define other model kwargs + model_kwargs["output_attentions"] = generation_config.output_attentions + model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + model_kwargs["use_cache"] = generation_config.use_cache + + accepts_attention_mask = "attention_mask" in set( + inspect.signature(self.forward).parameters.keys() + ) + requires_attention_mask = "encoder_outputs" not in model_kwargs + + if ( + model_kwargs.get("attention_mask", None) is None + and requires_attention_mask + and accepts_attention_mask + ): + model_kwargs[ + "attention_mask" + ] = self._prepare_attention_mask_for_generation( + inputs_tensor, + generation_config.pad_token_id, + generation_config.eos_token_id, + ) + + # decoder-only models should use left-padding for generation + if not self.config.is_encoder_decoder: + if ( + generation_config.pad_token_id is not None + and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) + > 0 + ): + logger.warning( + "A decoder-only architecture is being used, but right-padding was detected! For correct " + "generation results, please set `padding_side='left'` when initializing the tokenizer." 
+ ) + + if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created + # and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name + ) + + # 5. Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + input_ids = self._prepare_decoder_input_ids_for_generation( + batch_size, + decoder_start_token_id=generation_config.decoder_start_token_id, + bos_token_id=generation_config.bos_token_id, + model_kwargs=model_kwargs, + device=inputs_tensor.device, + ) + else: + # if decoder-only then inputs_tensor has to be `input_ids` + input_ids = inputs_tensor + + # 6. Prepare `max_length` depending on other stopping criteria. + input_ids_seq_length = input_ids.shape[-1] + has_default_max_length = ( + kwargs.get("max_length") is None + and generation_config.max_length is not None + ) + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to" + f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the" + " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif has_default_max_length and generation_config.max_new_tokens is not None: + generation_config.max_length = ( + generation_config.max_new_tokens + input_ids_seq_length + ) + elif ( + not has_default_max_length and generation_config.max_new_tokens is not None + ): + raise ValueError( + "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" + " limit to the generated output length. Remove one of those arguments. 
Please refer to the" + " documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + + if ( + generation_config.min_length is not None + and generation_config.min_length > generation_config.max_length + ): + raise ValueError( + f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" + f" the maximum length ({generation_config.max_length})" + ) + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = ( + "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + ) + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 7. determine generation mode + is_constraint_gen_mode = ( + generation_config.constraints is not None + or generation_config.force_words_ids is not None + ) + + is_contrastive_search_gen_mode = ( + generation_config.top_k is not None + and generation_config.top_k > 1 + and generation_config.do_sample is False + and generation_config.penalty_alpha is not None + and generation_config.penalty_alpha > 0 + ) + + is_greedy_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_sample_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + and generation_config.do_stream is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_sample_gen_stream_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_stream is True + and not is_constraint_gen_mode + and not 
is_contrastive_search_gen_mode + ) + is_beam_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_beam_sample_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_group_beam_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups > 1) + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + + if generation_config.num_beam_groups > generation_config.num_beams: + raise ValueError( + "`num_beam_groups` has to be smaller or equal to `num_beams`" + ) + if is_group_beam_gen_mode and generation_config.do_sample is True: + raise ValueError( + "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." + ) + + if self.device.type != input_ids.device.type: + warnings.warn( + "You are calling .generate() with the `input_ids` being on a device type different" + f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" + f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." + " Please make sure that you have put `input_ids` to the" + f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" + " running `.generate()`.", + UserWarning, + ) + # 8. prepare distribution pre_processing samplers + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + # 9. 
prepare stopping criteria + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + # 10. go into different generation modes + if is_greedy_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " greedy search." + ) + + # 11. run greedy search + return self.greedy_search( + input_ids, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_contrastive_search_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " contrastive search." + ) + + return self.contrastive_search( + input_ids, + top_k=generation_config.top_k, + penalty_alpha=generation_config.penalty_alpha, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_sample_gen_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. 
run sample + return self.sample( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + elif is_sample_gen_stream_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. run sample + return self.sample_stream( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + elif is_beam_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + ) + # 12. 
interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_beam_sample_gen_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + # 12. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size * generation_config.num_return_sequences, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + ) + + # 13. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams + * generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 14. 
run beam sample + return self.beam_sample( + input_ids, + beam_scorer, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_group_beam_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if generation_config.num_beams % generation_config.num_beam_groups != 0: + raise ValueError( + "`num_beams` should be divisible by `num_beam_groups` for group beam search." + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + has_default_typical_p = ( + kwargs.get("typical_p") is None and generation_config.typical_p == 1.0 + ) + if not has_default_typical_p: + raise ValueError( + "Decoder argument `typical_p` is not supported with beam groups." + ) + + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + max_length=stopping_criteria.max_length, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + num_beam_groups=generation_config.num_beam_groups, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. 
run beam search + return self.group_beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_constraint_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + if generation_config.num_beams <= 1: + raise ValueError( + "`num_beams` needs to be greater than 1 for constrained generation." + ) + + if generation_config.do_sample: + raise ValueError( + "`do_sample` needs to be false for constrained generation." + ) + + if ( + generation_config.num_beam_groups is not None + and generation_config.num_beam_groups > 1 + ): + raise ValueError( + "`num_beam_groups` not supported yet for constrained generation." + ) + + final_constraints = [] + if generation_config.constraints is not None: + final_constraints = generation_config.constraints + + if generation_config.force_words_ids is not None: + + def typeerror(): + raise ValueError( + "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" + f"of positive integers, but is {generation_config.force_words_ids}." 
+ ) + + if ( + not isinstance(generation_config.force_words_ids, list) + or len(generation_config.force_words_ids) == 0 + ): + typeerror() + + for word_ids in generation_config.force_words_ids: + if isinstance(word_ids[0], list): + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any( + not isinstance(token_ids, list) for token_ids in word_ids + ): + typeerror() + if any( + any( + (not isinstance(token_id, int) or token_id < 0) + for token_id in token_ids + ) + for token_ids in word_ids + ): + typeerror() + + constraint = DisjunctiveConstraint(word_ids) + else: + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any( + (not isinstance(token_id, int) or token_id < 0) + for token_id in word_ids + ): + typeerror() + + constraint = PhrasalConstraint(word_ids) + final_constraints.append(constraint) + + # 11. prepare beam search scorer + constrained_beam_scorer = ConstrainedBeamSearchScorer( + constraints=final_constraints, + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. 
run beam search + return self.constrained_beam_search( + input_ids, + constrained_beam_scorer=constrained_beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + @torch.no_grad() + def sample_stream( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[SampleOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + + + In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead. + For an overview of generation strategies and code examples, check the [following + guide](./generation_strategies). + + + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. 
+ stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. 
+ + Return: + [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... TopKLogitsWarper, + ... TemperatureLogitsWarper, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token + >>> model.config.pad_token_id = model.config.eos_token_id + >>> model.generation_config.pad_token_id = model.config.eos_token_id + + >>> input_prompt = "Today is a beautiful day, and" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id), + ... ] + ... ) + >>> # instantiate logits processors + >>> logits_warper = LogitsProcessorList( + ... [ + ... TopKLogitsWarper(50), + ... TemperatureLogitsWarper(0.7), + ... ] + ... ) + + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> outputs = model.sample( + ... input_ids, + ... logits_processor=logits_processor, + ... logits_warper=logits_warper, + ... stopping_criteria=stopping_criteria, + ... 
) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] + ```""" + # init values + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length + ) + logits_warper = ( + logits_warper if logits_warper is not None else LogitsProcessorList() + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + eos_token_id = ( + eos_token_id + if eos_token_id is not None + else self.generation_config.eos_token_id + ) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if 
(return_dict_in_generate and output_hidden_states) else None + ) + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + this_peer_finished = False # used by synced_gpus only + # auto-regressive generation + while True: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0 + ).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # sample + 
probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError( + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." + ) + next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( + 1 - unfinished_sequences + ) + yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1]) + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul( + (sum(next_tokens != i for i in eos_token_id)).long() + ) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + +def init_stream_support(): + """Overload PreTrainedModel for streaming.""" + PreTrainedModel.generate_stream = NewGenerationMixin.generate + PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream + + +if __name__ == "__main__": + from transformers import PreTrainedModel + from transformers import AutoTokenizer, AutoModelForCausalLM + + PreTrainedModel.generate = NewGenerationMixin.generate + PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream + model = AutoModelForCausalLM.from_pretrained( + "bigscience/bloom-560m", torch_dtype=torch.float16 + ) + + tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") + model = model.to("cuda:0") + model = model.eval() + prompt_text = "hello? 
\n" + input_ids = tokenizer( + prompt_text, return_tensors="pt", add_special_tokens=False + ).input_ids + input_ids = input_ids.to("cuda:0") + + with torch.no_grad(): + result = model.generate( + input_ids, + max_new_tokens=200, + do_sample=True, + top_k=30, + top_p=0.85, + temperature=0.35, + repetition_penalty=1.2, + early_stopping=True, + seed=0, + ) + print(tokenizer.decode(result, skip_special_tokens=True)) + generator = model.generate( + input_ids, + max_new_tokens=200, + do_sample=True, + top_k=30, + top_p=0.85, + temperature=0.35, + repetition_penalty=1.2, + early_stopping=True, + seed=0, + do_stream=True, + ) + stream_result = "" + for x in generator: + chunk = tokenizer.decode(x, skip_special_tokens=True) + stream_result += chunk + print(stream_result) diff --git a/tortoise/models/transformer.py b/tortoise/models/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..707e9ebaea2be706427b8eb663e75ef9d46c5de7 --- /dev/null +++ b/tortoise/models/transformer.py @@ -0,0 +1,219 @@ +from functools import partial + +import torch +import torch.nn.functional as F +from einops import rearrange +from rotary_embedding_torch import RotaryEmbedding, broadcat +from torch import nn + + +# helpers + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def cast_tuple(val, depth = 1): + if isinstance(val, list): + val = tuple(val) + return val if isinstance(val, tuple) else (val,) * depth + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def stable_softmax(t, dim = -1, alpha = 32 ** 2): + t = t / alpha + t = t - torch.amax(t, dim = dim, keepdim = True).detach() + return (t * alpha).softmax(dim = dim) + + +def route_args(router, args, depth): + routed_args = [(dict(), dict()) for _ in range(depth)] + matched_keys = [key for key in args.keys() if key in router] + + for key in matched_keys: + val = args[key] + for depth, ((f_args, g_args), routes) in 
enumerate(zip(routed_args, router[key])): + new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) + routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) + return routed_args + + +# classes +class SequentialSequence(nn.Module): + def __init__(self, layers, args_route = {}, layer_dropout = 0.): + super().__init__() + assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers' + self.layers = layers + self.args_route = args_route + self.layer_dropout = layer_dropout + + def forward(self, x, **kwargs): + args = route_args(self.args_route, kwargs, len(self.layers)) + layers_and_args = list(zip(self.layers, args)) + + for (f, g), (f_args, g_args) in layers_and_args: + x = x + f(x, **f_args) + x = x + g(x, **g_args) + return x + + +class DivideMax(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + maxes = x.amax(dim = self.dim, keepdim = True).detach() + return x / maxes + + +# https://arxiv.org/abs/2103.17239 +class LayerScale(nn.Module): + def __init__(self, dim, depth, fn): + super().__init__() + if depth <= 18: + init_eps = 0.1 + elif depth > 18 and depth <= 24: + init_eps = 1e-5 + else: + init_eps = 1e-6 + + scale = torch.zeros(1, 1, dim).fill_(init_eps) + self.scale = nn.Parameter(scale) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) * self.scale + +# layer norm + + +class PreNorm(nn.Module): + def __init__(self, dim, fn, sandwich = False): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity() + self.fn = fn + + def forward(self, x, **kwargs): + x = self.norm(x) + x = self.fn(x, **kwargs) + return self.norm_out(x) + +# feed forward + + +class GEGLU(nn.Module): + def forward(self, x): + x, gates = x.chunk(2, dim = -1) + return x * F.gelu(gates) + + +class 
FeedForward(nn.Module): + def __init__(self, dim, dropout = 0., mult = 4.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, dim * mult * 2), + GEGLU(), + nn.Dropout(dropout), + nn.Linear(dim * mult, dim) + ) + + def forward(self, x): + return self.net(x) + +# Attention + + +class Attention(nn.Module): + def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.seq_len = seq_len + self.scale = dim_head ** -0.5 + + self.causal = causal + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x, mask = None): + b, n, _, h, device = *x.shape, self.heads, x.device + softmax = torch.softmax + + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + + q = q * self.scale + + dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) + mask_value = max_neg_value(dots) + + if exists(mask): + mask = rearrange(mask, 'b j -> b () () j') + dots.masked_fill_(~mask, mask_value) + del mask + + if self.causal: + i, j = dots.shape[-2:] + mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() + dots.masked_fill_(mask, mask_value) + + attn = softmax(dots, dim=-1) + + out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return out + + +# main transformer class +class Transformer(nn.Module): + def __init__( + self, + *, + dim, + depth, + seq_len, + causal = True, + heads = 8, + dim_head = 64, + ff_mult = 4, + attn_dropout = 0., + ff_dropout = 0., + sparse_attn = False, + sandwich_norm = False, + ): + super().__init__() + layers = nn.ModuleList([]) + sparse_layer = cast_tuple(sparse_attn, depth) + + for ind, sparse_attn in zip(range(depth), sparse_layer): + attn = Attention(dim, causal = causal, seq_len 
= seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout) + + ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout) + + layers.append(nn.ModuleList([ + LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)), + LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich = sandwich_norm)) + ])) + + execute_type = SequentialSequence + route_attn = ((True, False),) * depth + attn_route_map = {'mask': route_attn} + + self.layers = execute_type(layers, args_route = attn_route_map) + + def forward(self, x, **kwargs): + return self.layers(x, **kwargs) diff --git a/tortoise/models/vocoder.py b/tortoise/models/vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8b60dbda152c04e3ca3f0eb649fa617860b9f35b --- /dev/null +++ b/tortoise/models/vocoder.py @@ -0,0 +1,327 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +MAX_WAV_VALUE = 32768.0 + +class KernelPredictor(torch.nn.Module): + ''' Kernel predictor for the location-variable convolutions''' + + def __init__( + self, + cond_channels, + conv_in_channels, + conv_out_channels, + conv_layers, + conv_kernel_size=3, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + kpnet_nonlinear_activation="LeakyReLU", + kpnet_nonlinear_activation_params={"negative_slope": 0.1}, + ): + ''' + Args: + cond_channels (int): number of channel for the conditioning sequence, + conv_in_channels (int): number of channel for the input sequence, + conv_out_channels (int): number of channel for the output sequence, + conv_layers (int): number of layers + ''' + super().__init__() + + self.conv_in_channels = conv_in_channels + self.conv_out_channels = conv_out_channels + self.conv_kernel_size = conv_kernel_size + self.conv_layers = conv_layers + + kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w + kpnet_bias_channels = conv_out_channels * conv_layers # l_b + + self.input_conv = nn.Sequential( + 
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.residual_convs = nn.ModuleList() + padding = (kpnet_conv_size - 1) // 2 + for _ in range(3): + self.residual_convs.append( + nn.Sequential( + nn.Dropout(kpnet_dropout), + nn.utils.weight_norm( + nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, + bias=True)), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + nn.utils.weight_norm( + nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, + bias=True)), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + ) + self.kernel_conv = nn.utils.weight_norm( + nn.Conv1d(kpnet_hidden_channels, kpnet_kernel_channels, kpnet_conv_size, padding=padding, bias=True)) + self.bias_conv = nn.utils.weight_norm( + nn.Conv1d(kpnet_hidden_channels, kpnet_bias_channels, kpnet_conv_size, padding=padding, bias=True)) + + def forward(self, c): + ''' + Args: + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + ''' + batch, _, cond_length = c.shape + c = self.input_conv(c) + for residual_conv in self.residual_convs: + residual_conv.to(c.device) + c = c + residual_conv(c) + k = self.kernel_conv(c) + b = self.bias_conv(c) + kernels = k.contiguous().view( + batch, + self.conv_layers, + self.conv_in_channels, + self.conv_out_channels, + self.conv_kernel_size, + cond_length, + ) + bias = b.contiguous().view( + batch, + self.conv_layers, + self.conv_out_channels, + cond_length, + ) + + return kernels, bias + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.input_conv[0]) + nn.utils.remove_weight_norm(self.kernel_conv) + nn.utils.remove_weight_norm(self.bias_conv) + for block in self.residual_convs: + nn.utils.remove_weight_norm(block[1]) + nn.utils.remove_weight_norm(block[3]) + + +class 
LVCBlock(torch.nn.Module): + '''the location-variable convolutions''' + + def __init__( + self, + in_channels, + cond_channels, + stride, + dilations=[1, 3, 9, 27], + lReLU_slope=0.2, + conv_kernel_size=3, + cond_hop_length=256, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + ): + super().__init__() + + self.cond_hop_length = cond_hop_length + self.conv_layers = len(dilations) + self.conv_kernel_size = conv_kernel_size + + self.kernel_predictor = KernelPredictor( + cond_channels=cond_channels, + conv_in_channels=in_channels, + conv_out_channels=2 * in_channels, + conv_layers=len(dilations), + conv_kernel_size=conv_kernel_size, + kpnet_hidden_channels=kpnet_hidden_channels, + kpnet_conv_size=kpnet_conv_size, + kpnet_dropout=kpnet_dropout, + kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope} + ) + + self.convt_pre = nn.Sequential( + nn.LeakyReLU(lReLU_slope), + nn.utils.weight_norm(nn.ConvTranspose1d(in_channels, in_channels, 2 * stride, stride=stride, + padding=stride // 2 + stride % 2, output_padding=stride % 2)), + ) + + self.conv_blocks = nn.ModuleList() + for dilation in dilations: + self.conv_blocks.append( + nn.Sequential( + nn.LeakyReLU(lReLU_slope), + nn.utils.weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size, + padding=dilation * (conv_kernel_size - 1) // 2, dilation=dilation)), + nn.LeakyReLU(lReLU_slope), + ) + ) + + def forward(self, x, c): + ''' forward propagation of the location-variable convolutions. 
+ Args: + x (Tensor): the input sequence (batch, in_channels, in_length) + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + + Returns: + Tensor: the output sequence (batch, in_channels, in_length) + ''' + _, in_channels, _ = x.shape # (B, c_g, L') + + x = self.convt_pre(x) # (B, c_g, stride * L') + kernels, bias = self.kernel_predictor(c) + + for i, conv in enumerate(self.conv_blocks): + output = conv(x) # (B, c_g, stride * L') + + k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length) + b = bias[:, i, :, :] # (B, 2 * c_g, cond_length) + + output = self.location_variable_convolution(output, k, b, + hop_size=self.cond_hop_length) # (B, 2 * c_g, stride * L'): LVC + x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh( + output[:, in_channels:, :]) # (B, c_g, stride * L'): GAU + + return x + + def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256): + ''' perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. + Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length). + kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) + bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) + dilation (int): the dilation of convolution. + hop_size (int): the hop_size of the conditioning sequence. + Returns: + (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). 
+ ''' + batch, _, in_length = x.shape + batch, _, out_channels, kernel_size, kernel_length = kernel.shape + assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched" + + padding = dilation * int((kernel_size - 1) / 2) + x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding) + x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) + + if hop_size < dilation: + x = F.pad(x, (0, dilation), 'constant', 0) + x = x.unfold(3, dilation, + dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) + x = x[:, :, :, :, :hop_size] + x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) + x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) + + o = torch.einsum('bildsk,biokl->bolsd', x, kernel) + o = o.to(memory_format=torch.channels_last_3d) + bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d) + o = o + bias + o = o.contiguous().view(batch, out_channels, -1) + + return o + + def remove_weight_norm(self): + self.kernel_predictor.remove_weight_norm() + nn.utils.remove_weight_norm(self.convt_pre[1]) + for block in self.conv_blocks: + nn.utils.remove_weight_norm(block[1]) + + +class UnivNetGenerator(nn.Module): + """ + UnivNet Generator + + Originally from https://github.com/mindslab-ai/univnet/blob/master/model/generator.py. + """ + + def __init__(self, noise_dim=64, channel_size=32, dilations=[1,3,9,27], strides=[8,8,4], lReLU_slope=.2, kpnet_conv_size=3, + # Below are MEL configurations options that this generator requires. 
+ hop_length=256, n_mel_channels=100): + super(UnivNetGenerator, self).__init__() + self.mel_channel = n_mel_channels + self.noise_dim = noise_dim + self.hop_length = hop_length + channel_size = channel_size + kpnet_conv_size = kpnet_conv_size + + self.res_stack = nn.ModuleList() + hop_length = 1 + for stride in strides: + hop_length = stride * hop_length + self.res_stack.append( + LVCBlock( + channel_size, + n_mel_channels, + stride=stride, + dilations=dilations, + lReLU_slope=lReLU_slope, + cond_hop_length=hop_length, + kpnet_conv_size=kpnet_conv_size + ) + ) + + self.conv_pre = \ + nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode='reflect')) + + self.conv_post = nn.Sequential( + nn.LeakyReLU(lReLU_slope), + nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')), + nn.Tanh(), + ) + + def forward(self, c, z): + ''' + Args: + c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length) + z (Tensor): the noise sequence (batch, noise_dim, in_length) + + ''' + z = self.conv_pre(z) # (B, c_g, L) + + for res_block in self.res_stack: + res_block.to(z.device) + z = res_block(z, c) # (B, c_g, L * s_0 * ... 
* s_i) + + z = self.conv_post(z) # (B, 1, L * 256) + + return z + + def eval(self, inference=False): + super(UnivNetGenerator, self).eval() + # don't remove weight norm while validation in training loop + if inference: + self.remove_weight_norm() + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv_pre) + + for layer in self.conv_post: + if len(layer.state_dict()) != 0: + nn.utils.remove_weight_norm(layer) + + for res_block in self.res_stack: + res_block.remove_weight_norm() + + def inference(self, c, z=None): + # pad input mel with zeros to cut artifact + # see https://github.com/seungwonpark/melgan/issues/8 + zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device) + mel = torch.cat((c, zero), dim=2) + + if z is None: + z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device) + + audio = self.forward(mel, z) + audio = audio[:, :, :-(self.hop_length * 10)] + audio = audio.clamp(min=-1, max=1) + return audio + + +if __name__ == '__main__': + model = UnivNetGenerator() + + c = torch.randn(3, 100, 10) + z = torch.randn(3, 64, 10) + print(c.shape) + + y = model(c, z) + print(y.shape) + assert y.shape == torch.Size([3, 1, 2560]) + + pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(pytorch_total_params) diff --git a/tortoise/models/xtransformers.py b/tortoise/models/xtransformers.py new file mode 100644 index 0000000000000000000000000000000000000000..8be2df455c46bf8c89efb0d5fdbb704a9fb622f6 --- /dev/null +++ b/tortoise/models/xtransformers.py @@ -0,0 +1,1248 @@ +import math +from collections import namedtuple +from functools import partial +from inspect import isfunction + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import nn, einsum + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', 
+ 'attn_intermediates', + 'past_key_values', +]) + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def cast_tuple(val, depth): + return val if isinstance(val, tuple) else (val,) * depth + + +class always(): + def __init__(self, val): + self.val = val + + def __call__(self, *args, **kwargs): + return self.val + + +class not_equals(): + def __init__(self, val): + self.val = val + + def __call__(self, x, *args, **kwargs): + return x != self.val + + +class equals(): + def __init__(self, val): + self.val = val + + def __call__(self, x, *args, **kwargs): + return x == self.val + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +def l2norm(t): + return F.normalize(t, p=2, dim=-1) + + +# init helpers + +def init_zero_(layer): + nn.init.constant_(layer.weight, 0.) + if exists(layer.bias): + nn.init.constant_(layer.bias, 0.) + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# activations + +class ReluSquared(nn.Module): + def forward(self, x): + return F.relu(x) ** 2 + + +# positional embeddings + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + 
super().__init__() + self.scale = dim ** -0.5 + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + pos_emb = self.emb(n) + pos_emb = rearrange(pos_emb, 'n d -> () n d') + return pos_emb * self.scale + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return rearrange(emb, 'n d -> () n d') + + +class RelativePositionBias(nn.Module): + def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8): + super().__init__() + self.scale = scale + self.causal = causal + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + if not causal: + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + else: + n = torch.max(n, torch.zeros_like(n)) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, qk_dots): + i, j, device = *qk_dots.shape[-2:], qk_dots.device + q_pos = torch.arange(i, dtype=torch.long, device=device) + k_pos = torch.arange(j, dtype=torch.long, device=device) + rel_pos = k_pos[None, :] - 
q_pos[:, None] + rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets, + max_distance=self.max_distance) + values = self.relative_attention_bias(rp_bucket) + bias = rearrange(values, 'i j h -> () h i j') + return qk_dots + (bias * self.scale) + + +class AlibiPositionalBias(nn.Module): + def __init__(self, heads, **kwargs): + super().__init__() + self.heads = heads + slopes = torch.Tensor(self._get_slopes(heads)) + slopes = rearrange(slopes, 'h -> () h () ()') + self.register_buffer('slopes', slopes, persistent=False) + self.register_buffer('bias', None, persistent=False) + + @staticmethod + def _get_slopes(heads): + def get_slopes_power_of_2(n): + start = (2 ** (-2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio ** i for i in range(n)] + + if math.log2(heads).is_integer(): + return get_slopes_power_of_2(heads) + + closest_power_of_2 = 2 ** math.floor(math.log2(heads)) + return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][ + :heads - closest_power_of_2] + + def forward(self, qk_dots): + h, i, j, device = *qk_dots.shape[-3:], qk_dots.device + + if exists(self.bias) and self.bias.shape[-1] >= j: + return qk_dots + self.bias[..., :j] + + bias = torch.arange(j, device=device) + bias = rearrange(bias, 'j -> () () () j') + bias = bias * self.slopes + + num_heads_unalibied = h - bias.shape[1] + bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied)) + + self.register_buffer('bias', bias, persistent=False) + return qk_dots + self.bias + + +class LearnedAlibiPositionalBias(AlibiPositionalBias): + def __init__(self, heads, bidirectional=False): + super().__init__(heads) + los_slopes = torch.log(self.slopes) + self.learned_logslopes = nn.Parameter(los_slopes) + + self.bidirectional = bidirectional + if self.bidirectional: + self.learned_logslopes_future = nn.Parameter(los_slopes) + + def forward(self, qk_dots): + h, i, j, device = *qk_dots.shape[-3:], 
qk_dots.device + + def get_slopes(param): + return F.pad(param.exp(), (0, 0, 0, 0, 0, h - param.shape[1])) + + if exists(self.bias) and self.bias.shape[-1] >= j: + bias = self.bias[..., :i, :j] + else: + i_arange = torch.arange(i, device=device) + j_arange = torch.arange(j, device=device) + bias = rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1') + self.register_buffer('bias', bias, persistent=False) + + if self.bidirectional: + past_slopes = get_slopes(self.learned_logslopes) + future_slopes = get_slopes(self.learned_logslopes_future) + bias = torch.tril(bias * past_slopes) + torch.triu(bias * future_slopes) + else: + slopes = get_slopes(self.learned_logslopes) + bias = bias * slopes + + return qk_dots + bias + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, max_seq_len, device): + t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq) + freqs = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return rearrange(emb, 'n d -> () () n d') + + +def rotate_half(x): + x = rearrange(x, '... (j d) -> ... 
j d', j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(t, freqs): + seq_len = t.shape[-2] + freqs = freqs[:, :, -seq_len:] + return (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) + + +# norms + +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + out = self.fn(x, **kwargs) + scale_fn = lambda t: t * self.value + + if not isinstance(out, tuple): + return scale_fn(out) + + return (scale_fn(out[0]), *out[1:]) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + out = self.fn(x, **kwargs) + rezero_fn = lambda t: t * self.g + + if not isinstance(out, tuple): + return rezero_fn(out) + + return (rezero_fn(out[0]), *out[1:]) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSScaleShiftNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + self.scale_shift_process = nn.Linear(dim * 2, dim * 2) + + def forward(self, x, norm_scale_shift_inp): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + norm = x / norm.clamp(min=self.eps) * self.g + + ss_emb = self.scale_shift_process(norm_scale_shift_inp) + scale, shift = torch.chunk(ss_emb, 
2, dim=1) + h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + return h + + +# residual and residual gates + +class Residual(nn.Module): + def __init__(self, dim, scale_residual=False): + super().__init__() + self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None + + def forward(self, x, residual): + if exists(self.residual_scale): + residual = residual * self.residual_scale + + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim, scale_residual=False): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None + + def forward(self, x, residual): + if exists(self.residual_scale): + residual = residual * self.residual_scale + + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# token shifting + +def shift(t, amount, mask=None): + if amount == 0: + return t + + if exists(mask): + t = t.masked_fill(~mask[..., None], 0.) + + return F.pad(t, (0, 0, amount, -amount), value=0.) 
+ + +class ShiftTokens(nn.Module): + def __init__(self, shifts, fn): + super().__init__() + self.fn = fn + self.shifts = tuple(shifts) + + def forward(self, x, **kwargs): + mask = kwargs.get('mask', None) + shifts = self.shifts + segments = len(shifts) + feats_per_shift = x.shape[-1] // segments + splitted = x.split(feats_per_shift, dim=-1) + segments_to_shift, rest = splitted[:segments], splitted[segments:] + segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts))) + x = torch.cat((*segments_to_shift, *rest), dim=-1) + return self.fn(x, **kwargs) + + +# feedforward + +class GLU(nn.Module): + def __init__(self, dim_in, dim_out, activation): + super().__init__() + self.act = activation + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * self.act(gate) + + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out=None, + mult=4, + glu=False, + relu_squared=False, + post_act_ln=False, + dropout=0., + zero_init_output=False + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + activation = ReluSquared() if relu_squared else nn.GELU() + + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + activation + ) if not glu else GLU(dim, inner_dim, activation) + + self.net = nn.Sequential( + project_in, + nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + # init last linear layer to 0 + if zero_init_output: + init_zero_(self.net[-1]) + + def forward(self, x): + return self.net(x) + + +# attention. 
+ +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + talking_heads=False, + head_scale=False, + collab_heads=False, + collab_compression=.3, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False, + gate_values=False, + zero_init_output=False, + max_attend_past=None, + qk_norm=False, + scale_init_value=None, + rel_pos_bias=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + ): + super().__init__() + self.scale = dim_head ** -0.5 + + self.heads = heads + self.causal = causal + self.max_attend_past = max_attend_past + + qk_dim = v_dim = dim_head * heads + + # collaborative heads + self.collab_heads = collab_heads + if self.collab_heads: + qk_dim = int(collab_compression * qk_dim) + self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim)) + + self.to_q = nn.Linear(dim, qk_dim, bias=False) + self.to_k = nn.Linear(dim, qk_dim, bias=False) + self.to_v = nn.Linear(dim, v_dim, bias=False) + + self.dropout = nn.Dropout(dropout) + + # add GLU gating for aggregated values, from alphafold2 + self.to_v_gate = None + if gate_values: + self.to_v_gate = nn.Linear(dim, v_dim) + nn.init.constant_(self.to_v_gate.weight, 0) + nn.init.constant_(self.to_v_gate.bias, 1) + + # cosine sim attention + self.qk_norm = qk_norm + if qk_norm: + scale_init_value = default(scale_init_value, + -3) # if not provided, initialize as though it were sequence length of 1024 + self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # head scaling + self.head_scale = head_scale + if head_scale: + self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + self.attn_fn = 
F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim) + + self.rel_pos_bias = rel_pos_bias + if rel_pos_bias: + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads, + num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance) + + # init output projection 0 + if zero_init_output: + init_zero_(self.to_out) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + attn_mask=None, + sinusoidal_emb=None, + rotary_pos_emb=None, + prev_attn=None, + mem=None, + layer_past=None, + ): + b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists( + context) + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + if not collab_heads: + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + else: + q = einsum('b i d, h d -> b h i d', q, self.collab_mixing) + k = rearrange(k, 'b 
n d -> b () n d') + v = rearrange(v, 'b n (h d) -> b h n d', h=h) + + if layer_past is not None: + past_key, past_value = layer_past + k = torch.cat([past_key, k], dim=-2) + v = torch.cat([past_value, v], dim=-2) + k_cache = k + v_cache = v + + if exists(rotary_pos_emb) and not has_context: + l = rotary_pos_emb.shape[-1] + (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v)) + ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl)) + q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + if collab_heads: + k = k.expand(-1, h, -1, -1) + + if self.qk_norm: + q, k = map(l2norm, (q, k)) + scale = 1 / (self.scale.exp().clamp(min=1e-2)) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots.clone() + + if talking_heads: + dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if self.rel_pos_bias: + dots = self.rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if exists(attn_mask): + assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 
4' + if attn_mask.ndim == 2: + attn_mask = rearrange(attn_mask, 'i j -> () () i j') + elif attn_mask.ndim == 3: + attn_mask = rearrange(attn_mask, 'h i j -> () h i j') + dots.masked_fill_(~attn_mask, mask_value) + + if exists(self.max_attend_past): + i, j = dots.shape[-2:] + range_q = torch.arange(j - i, j, device=device) + range_k = torch.arange(j, device=device) + dist = rearrange(range_q, 'i -> () () i ()') - rearrange(range_k, 'j -> () () () j') + mask = dist > self.max_attend_past + dots.masked_fill_(mask, mask_value) + del mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn.clone() + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + + if head_scale: + out = out * self.head_scale_params + + out = rearrange(out, 'b h n d -> b n (h d)') + + if exists(self.to_v_gate): + gates = self.to_v_gate(x) + out = out * gates.sigmoid() + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates, k_cache, v_cache + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rms_scaleshift_norm=False, + use_rmsnorm=False, + use_rezero=False, + alibi_pos_bias=False, + alibi_num_heads=None, + alibi_learned=False, + 
position_infused_attn=False, + rotary_pos_emb=False, + rotary_emb_dim=None, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + scale_residual=False, + shift_tokens=0, + sandwich_norm=False, + use_qk_norm_attn=False, + qk_norm_attn_seq_len=None, + zero_init_branch_output=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + self.causal = causal + + rel_pos_bias = 'rel_pos_bias' in attn_kwargs + self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + + rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32) + self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None + + assert not ( + alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both' + + if alibi_pos_bias: + alibi_num_heads = default(alibi_num_heads, heads) + assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads' + alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias + self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal) + else: + self.rel_pos = None + + assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm' + self.pre_norm = pre_norm + self.sandwich_norm = sandwich_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + self.cross_attend = cross_attend + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if 
use_rmsnorm else norm_class + norm_class = RMSScaleShiftNorm if use_rms_scaleshift_norm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + # qk normalization + + if use_qk_norm_attn: + attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists( + qk_norm_attn_seq_len) else None + attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value} + + # zero init + + if zero_init_branch_output: + attn_kwargs = {**attn_kwargs, 'zero_init_output': True} + ff_kwargs = {**ff_kwargs, 'zero_init_output': True} + + # calculate layer block order + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = 
len(list(filter(equals('a'), layer_types))) + + # calculate token shifting + + shift_tokens = cast_tuple(shift_tokens, len(layer_types)) + + # iterate and construct layers + + for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)): + is_last_layer = ind == (len(self.layer_types) - 1) + + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if layer_shift_tokens > 0: + shift_range_upper = layer_shift_tokens + 1 + shift_range_lower = -layer_shift_tokens if not causal else 0 + layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer) + + if exists(branch_fn): + layer = branch_fn(layer) + + residual_fn = GRUGating if gate_residual else Residual + residual = residual_fn(dim, scale_residual=scale_residual) + + layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c') + + pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None + post_branch_norm = norm_fn() if sandwich_norm or layer_uses_qk_norm else None + post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None + + norms = nn.ModuleList([ + pre_branch_norm, + post_branch_norm, + post_main_norm + ]) + + self.layers.append(nn.ModuleList([ + norms, + layer, + residual + ])) + + def forward( + self, + x, + context=None, + full_context=None, # for passing a list of hidden states from an encoder + mask=None, + context_mask=None, + attn_mask=None, + mems=None, + return_hiddens=False, + norm_scale_shift_inp=None, + past_key_values=None, + expected_seq_len=None, + ): + + assert not (self.cross_attend ^ (exists(context) or exists( + full_context))), 'context must be passed in if cross_attend is set to True' + assert context is None or 
full_context is None, 'only one of full_context or context can be provided' + + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + norm_args = {} + if exists(norm_scale_shift_inp): + norm_args['norm_scale_shift_inp'] = norm_scale_shift_inp + + rotary_pos_emb = None + if exists(self.rotary_pos_emb): + if not self.training and self.causal: + assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`" + elif expected_seq_len is None: + expected_seq_len = 0 + seq_len = x.shape[1] + if past_key_values is not None: + seq_len += past_key_values[0][0].shape[-2] + max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len]) + rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device) + + present_key_values = [] + cross_attn_count = 0 + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + if layer_type == 'a': + layer_mem = mems.pop(0) if mems else None + + residual = x + + pre_branch_norm, post_branch_norm, post_main_norm = norm + + if exists(pre_branch_norm): + x = pre_branch_norm(x, **norm_args) + + if layer_type == 'a' or layer_type == 'c': + if past_key_values is not None: + layer_kv = past_key_values.pop(0) + layer_past = tuple(s.to(x.device) for s in layer_kv) + else: + layer_past = None + + if layer_type == 'a': + out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, + prev_attn, layer_mem, layer_past) + elif layer_type == 'c': + if exists(full_context): + out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None, + None, prev_attn, None, layer_past) + else: + out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past) + elif layer_type == 'f': + out = block(x) + + if 
layer_type == 'a' or layer_type == 'c' and present_key_values is not None: + present_key_values.append((k.detach(), v.detach())) + + if exists(post_branch_norm): + out = post_branch_norm(out, **norm_args) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if exists(post_main_norm): + x = post_main_norm(x, **norm_args) + + if layer_type == 'c': + cross_attn_count += 1 + + if layer_type == 'f': + hiddens.append(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates, + past_key_values=present_key_values + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + +class Decoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on decoder' + super().__init__(causal=True, **kwargs) + + +class CrossAttender(AttentionLayers): + def __init__(self, **kwargs): + super().__init__(cross_attend=True, only_cross=True, **kwargs) + + +class ViTransformerWrapper(nn.Module): + def __init__( + self, + *, + image_size, + patch_size, + attn_layers, + num_classes=None, + dropout=0., + emb_dropout=0. 
+ ): + super().__init__() + assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder' + assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size' + dim = attn_layers.dim + num_patches = (image_size // patch_size) ** 2 + patch_dim = 3 * patch_size ** 2 + + self.patch_size = patch_size + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) + self.patch_to_embedding = nn.Linear(patch_dim, dim) + self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) + self.dropout = nn.Dropout(emb_dropout) + + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None + + def forward( + self, + img, + return_embeddings=False + ): + p = self.patch_size + + x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p) + x = self.patch_to_embedding(x) + b, n, _ = x.shape + + cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embedding[:, :(n + 1)] + x = self.dropout(x) + + x = self.attn_layers(x) + x = self.norm(x) + + if not exists(self.mlp_head) or return_embeddings: + return x + + return self.mlp_head(x[:, 0]) + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + shift_mem_down=0, + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.shift_mem_down = shift_mem_down + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) 
else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + def init_(self): + nn.init.kaiming_normal_(self.token_emb.weight) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_hiddens=False, + return_attn=False, + mems=None, + use_cache=False, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x = x + self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + if self.shift_mem_down and exists(mems): + mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:] + mems = [*mems_r, *mems_l] + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_hiddens: + hiddens = intermediates.hiddens + return out, hiddens + + res = [out] + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + res.append(attn_maps) + if use_cache: + res.append(intermediates.past_key_values) + + if len(res) > 1: + return tuple(res) + return res[0] + + +class ContinuousTransformerWrapper(nn.Module): 
+ def __init__( + self, + *, + max_seq_len, + attn_layers, + dim_in=None, + dim_out=None, + emb_dim=None, + emb_dropout=0., + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + + self.max_seq_len = max_seq_len + + self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity() + + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity() + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_attn=False, + mems=None, + use_cache=False, + **kwargs + ): + b, n, _, device = *x.shape, x.device + + x = self.project_in(x) + x = x + self.pos_emb(x) + x = self.emb_dropout(x) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + out = self.project_out(x) if not return_embeddings else x + + res = [out] + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + res.append(attn_maps) + if use_cache: + res.append(intermediates.past_key_values) + + if len(res) > 1: + return tuple(res) + return res[0] + diff --git a/tortoise/read.py b/tortoise/read.py new file mode 100644 index 0000000000000000000000000000000000000000..e5839aa89522d4770ab3f53ef2aca5b7eb7eac84 --- /dev/null +++ b/tortoise/read.py @@ -0,0 +1,101 @@ +import argparse +import os +from time import time + +import torch +import torchaudio + +from api import TextToSpeech, MODELS_DIR +from utils.audio import load_audio, load_voices +from utils.text import split_and_recombine_text + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--textfile', 
type=str, help='A file containing the text to read.', default="tortoise/data/riding_hood.txt") + parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) ' + 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='pat') + parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/') + parser.add_argument('--output_name', type=str, help='How to name the output file', default='combined.wav') + parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard') + parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None) + parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice. Only the first candidate is actually used in the final product, the others can be used manually.', default=1) + parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this' + 'should only be specified if you have custom checkpoints.', default=MODELS_DIR) + parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None) + parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. 
Defaults to true.', default=True) + parser.add_argument('--use_deepspeed', type=bool, help='Use deepspeed for speed bump.', default=False) + parser.add_argument('--kv_cache', type=bool, help='If you disable this please wait for a long a time to get the output', default=True) + parser.add_argument('--half', type=bool, help="float16(half) precision inference if True it's faster and take less vram and ram", default=True) + + + args = parser.parse_args() + if torch.backends.mps.is_available(): + args.use_deepspeed = False + tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half) + + outpath = args.output_path + outname = args.output_name + selected_voices = args.voice.split(',') + regenerate = args.regenerate + if regenerate is not None: + regenerate = [int(e) for e in regenerate.split(',')] + + # Process text + with open(args.textfile, 'r', encoding='utf-8') as f: + text = ' '.join([l for l in f.readlines()]) + if '|' in text: + print("Found the '|' character in your text, which I will use as a cue for where to split it up. 
If this was not" + "your intent, please remove all '|' characters from the input.") + texts = text.split('|') + else: + texts = split_and_recombine_text(text) + + seed = int(time()) if args.seed is None else args.seed + for selected_voice in selected_voices: + voice_outpath = os.path.join(outpath, selected_voice) + os.makedirs(voice_outpath, exist_ok=True) + + if '&' in selected_voice: + voice_sel = selected_voice.split('&') + else: + voice_sel = [selected_voice] + + voice_samples, conditioning_latents = load_voices(voice_sel) + all_parts = [] + for j, text in enumerate(texts): + if regenerate is not None and j not in regenerate: + all_parts.append(load_audio(os.path.join(voice_outpath, f'{j}.wav'), 24000)) + continue + gen = tts.tts_with_preset(text, voice_samples=voice_samples, conditioning_latents=conditioning_latents, + preset=args.preset, k=args.candidates, use_deterministic_seed=seed) + if args.candidates == 1: + audio_ = gen.squeeze(0).cpu() + torchaudio.save(os.path.join(voice_outpath, f'{j}.wav'), audio_, 24000) + else: + candidate_dir = os.path.join(voice_outpath, str(j)) + os.makedirs(candidate_dir, exist_ok=True) + for k, g in enumerate(gen): + torchaudio.save(os.path.join(candidate_dir, f'{k}.wav'), g.squeeze(0).cpu(), 24000) + audio_ = gen[0].squeeze(0).cpu() + all_parts.append(audio_) + + if args.candidates == 1: + full_audio = torch.cat(all_parts, dim=-1) + torchaudio.save(os.path.join(voice_outpath, f"{outname}.wav"), full_audio, 24000) + + if args.produce_debug_state: + os.makedirs('debug_states', exist_ok=True) + dbg_state = (seed, texts, voice_samples, conditioning_latents) + torch.save(dbg_state, f'debug_states/read_debug_{selected_voice}.pth') + + # Combine each candidate's audio clips. 
+ if args.candidates > 1: + audio_clips = [] + for candidate in range(args.candidates): + for line in range(len(texts)): + wav_file = os.path.join(voice_outpath, str(line), f"{candidate}.wav") + audio_clips.append(load_audio(wav_file, 24000)) + audio_clips = torch.cat(audio_clips, dim=-1) + torchaudio.save(os.path.join(voice_outpath, f"{outname}_{candidate:02d}.wav"), audio_clips, 24000) + audio_clips = [] diff --git a/tortoise/read_fast.py b/tortoise/read_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..f2778d4a9a1a020487afee7e8dbf92834e210fbc --- /dev/null +++ b/tortoise/read_fast.py @@ -0,0 +1,77 @@ +import argparse +import os +from time import time + +import torch +import torchaudio + +from api_fast import TextToSpeech, MODELS_DIR +from utils.audio import load_audio, load_voices +from utils.text import split_and_recombine_text + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--textfile', type=str, help='A file containing the text to read.', default="tortoise/data/riding_hood.txt") + parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) ' + 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='lj') + parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/') + parser.add_argument('--output_name', type=str, help='How to name the output file', default='combined.wav') + parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard') + parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None) + parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. 
Tortoise automatically downloads these to .models, so this' + 'should only be specified if you have custom checkpoints.', default=MODELS_DIR) + parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None) + parser.add_argument('--use_deepspeed', type=bool, help='Use deepspeed for speed bump.', default=False) + parser.add_argument('--kv_cache', type=bool, help='If you disable this please wait for a long a time to get the output', default=True) + parser.add_argument('--half', type=bool, help="float16(half) precision inference if True it's faster and take less vram and ram", default=True) + + + args = parser.parse_args() + if torch.backends.mps.is_available(): + args.use_deepspeed = False + tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half) + + outpath = args.output_path + outname = args.output_name + selected_voices = args.voice.split(',') + regenerate = args.regenerate + if regenerate is not None: + regenerate = [int(e) for e in regenerate.split(',')] + + # Process text + with open(args.textfile, 'r', encoding='utf-8') as f: + text = ' '.join([l for l in f.readlines()]) + if '|' in text: + print("Found the '|' character in your text, which I will use as a cue for where to split it up. 
If this was not" + "your intent, please remove all '|' characters from the input.") + texts = text.split('|') + else: + texts = split_and_recombine_text(text) + + seed = int(time()) if args.seed is None else args.seed + for selected_voice in selected_voices: + voice_outpath = os.path.join(outpath, selected_voice) + os.makedirs(voice_outpath, exist_ok=True) + + if '&' in selected_voice: + voice_sel = selected_voice.split('&') + else: + voice_sel = [selected_voice] + + voice_samples, conditioning_latents = load_voices(voice_sel) + all_parts = [] + for j, text in enumerate(texts): + if regenerate is not None and j not in regenerate: + all_parts.append(load_audio(os.path.join(voice_outpath, f'{j}.wav'), 24000)) + continue + start_time = time() + gen = tts.tts(text, voice_samples=voice_samples, use_deterministic_seed=seed) + end_time = time() + audio_ = gen.squeeze(0).cpu() + print("Time taken to generate the audio: ", end_time - start_time, "seconds") + print("RTF: ", (end_time - start_time) / (audio_.shape[1] / 24000)) + torchaudio.save(os.path.join(voice_outpath, f'{j}.wav'), audio_, 24000) + all_parts.append(audio_) + full_audio = torch.cat(all_parts, dim=-1) + torchaudio.save(os.path.join(voice_outpath, f"{outname}.wav"), full_audio, 24000) diff --git a/tortoise/tts_stream.py b/tortoise/tts_stream.py new file mode 100644 index 0000000000000000000000000000000000000000..94eaff573493e7842bdad47145429f049fec2693 --- /dev/null +++ b/tortoise/tts_stream.py @@ -0,0 +1,85 @@ +import argparse +import os +from time import time + +import torch +import torchaudio + +from api_fast import TextToSpeech, MODELS_DIR +from utils.audio import load_audio, load_voices +from utils.text import split_and_recombine_text +import sounddevice as sd +import queue +import threading +def play_audio(audio_queue): + while True: + chunk = audio_queue.get() + if chunk is None: + break + sd.play(chunk.cpu().numpy(), samplerate=24000) + sd.wait() + +if __name__ == '__main__': + parser = 
argparse.ArgumentParser() + parser.add_argument('--textfile', type=str, help='A file containing the text to read.', default="tortoise/data/riding_hood.txt") + parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) ' + 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='lj') + parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/') + parser.add_argument('--output_name', type=str, help='How to name the output file', default='combined.wav') + parser.add_argument('--preset', type=str, help='Which voice preset to use.', default='standard') + parser.add_argument('--regenerate', type=str, help='Comma-separated list of clip numbers to re-generate, or nothing.', default=None) + parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this' + 'should only be specified if you have custom checkpoints.', default=MODELS_DIR) + parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None) + parser.add_argument('--use_deepspeed', type=bool, help='Use deepspeed for speed bump.', default=False) + parser.add_argument('--kv_cache', type=bool, help='If you disable this please wait for a long a time to get the output', default=True) + parser.add_argument('--half', type=bool, help="float16(half) precision inference if True it's faster and take less vram and ram", default=True) + + + args = parser.parse_args() + if torch.backends.mps.is_available(): + args.use_deepspeed = False + tts = TextToSpeech(models_dir=args.model_dir, use_deepspeed=args.use_deepspeed, kv_cache=args.kv_cache, half=args.half) + + outpath = args.output_path + outname = args.output_name + selected_voices = args.voice.split(',') + regenerate = args.regenerate + if regenerate is not 
None: + regenerate = [int(e) for e in regenerate.split(',')] + + # Process text + with open(args.textfile, 'r', encoding='utf-8') as f: + text = ' '.join([l for l in f.readlines()]) + if '|' in text: + print("Found the '|' character in your text, which I will use as a cue for where to split it up. If this was not" + "your intent, please remove all '|' characters from the input.") + texts = text.split('|') + else: + texts = split_and_recombine_text(text) + audio_queue = queue.Queue() + playback_thread = threading.Thread(target=play_audio, args=(audio_queue,)) + playback_thread.start() + + seed = int(time()) if args.seed is None else args.seed + for selected_voice in selected_voices: + voice_outpath = os.path.join(outpath, selected_voice) + os.makedirs(voice_outpath, exist_ok=True) + + if '&' in selected_voice: + voice_sel = selected_voice.split('&') + else: + voice_sel = [selected_voice] + + voice_samples, conditioning_latents = load_voices(voice_sel) + all_parts = [] + for j, text in enumerate(texts): + if regenerate is not None and j not in regenerate: + all_parts.append(load_audio(os.path.join(voice_outpath, f'{j}.wav'), 24000)) + continue + start_time = time() + audio_generator = tts.tts_stream(text, voice_samples=voice_samples, use_deterministic_seed=seed) + for wav_chunk in audio_generator: + audio_queue.put(wav_chunk) + audio_queue.put(None) + playback_thread.join() \ No newline at end of file diff --git a/tortoise/utils/__init__.py b/tortoise/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tortoise/utils/__pycache__/__init__.cpython-310.pyc b/tortoise/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2dae0a46a7dc92f7ddb2425f5b16214c03fcd59f Binary files /dev/null and b/tortoise/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/tortoise/utils/__pycache__/audio.cpython-310.pyc 
b/tortoise/utils/__pycache__/audio.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..354e093866f8e6fc4f7bf4d26bb5079927a0bd44 Binary files /dev/null and b/tortoise/utils/__pycache__/audio.cpython-310.pyc differ diff --git a/tortoise/utils/__pycache__/diffusion.cpython-310.pyc b/tortoise/utils/__pycache__/diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ce60104cf2aa2d914511600a7714774228a946d Binary files /dev/null and b/tortoise/utils/__pycache__/diffusion.cpython-310.pyc differ diff --git a/tortoise/utils/__pycache__/stft.cpython-310.pyc b/tortoise/utils/__pycache__/stft.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f1d977b584cca095da1e098f5cc5f4cba0d7e7a Binary files /dev/null and b/tortoise/utils/__pycache__/stft.cpython-310.pyc differ diff --git a/tortoise/utils/__pycache__/tokenizer.cpython-310.pyc b/tortoise/utils/__pycache__/tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b84890c465cee8b54000df17ed147b768659caf7 Binary files /dev/null and b/tortoise/utils/__pycache__/tokenizer.cpython-310.pyc differ diff --git a/tortoise/utils/__pycache__/typical_sampling.cpython-310.pyc b/tortoise/utils/__pycache__/typical_sampling.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5502d85de04831e4847ade22396e9b89f9a6ca0 Binary files /dev/null and b/tortoise/utils/__pycache__/typical_sampling.cpython-310.pyc differ diff --git a/tortoise/utils/__pycache__/wav2vec_alignment.cpython-310.pyc b/tortoise/utils/__pycache__/wav2vec_alignment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce596d382d4eec6a502da1a7231b042c81e4ca8b Binary files /dev/null and b/tortoise/utils/__pycache__/wav2vec_alignment.cpython-310.pyc differ diff --git a/tortoise/utils/audio.py b/tortoise/utils/audio.py new file mode 100644 index 
def load_wav_to_torch(full_path):
    """
    Read a WAV file and return its samples as a float32 tensor together with
    the file's sampling rate.

    Integer PCM samples are divided by their full-scale magnitude (2**31 for
    int32, 2**15 for int16); float16/float32 samples are returned unscaled.

    :param full_path: path or file-like object passed to scipy.io.wavfile.read.
    :return: a tuple (torch.FloatTensor of samples, sampling rate).
    :raises NotImplementedError: if the file's sample dtype is not one of
                                 int32, int16, float16 or float32.
    """
    sampling_rate, data = read(full_path)
    if data.dtype == np.int32:
        norm_fix = 2 ** 31
    elif data.dtype == np.int16:
        norm_fix = 2 ** 15
    elif data.dtype == np.float16 or data.dtype == np.float32:
        norm_fix = 1.
    else:
        # NotImplemented is a comparison sentinel, not an exception class; calling it
        # raised an unrelated TypeError and hid this message from the user.
        raise NotImplementedError(f"Provided data dtype not supported: {data.dtype}")
    return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """
    Log-compress a tensor after flooring its values at clip_val.

    PARAMS
    ------
    x: tensor to compress
    C: compression factor, multiplied in before the logarithm
    clip_val: lower bound applied to x so the logarithm stays finite
    """
    floored = torch.clamp(x, min=clip_val)
    return torch.log(floored * C)
class TacotronSTFT(torch.nn.Module):
    """
    Mel-spectrogram front end: an STFT, a mel filterbank projection and a
    log dynamic-range compression step.
    """

    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
                 mel_fmax=8000.0):
        super().__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)

        # librosa is only imported when a front end is actually constructed.
        from librosa.filters import mel as librosa_mel_fn
        filterbank = librosa_mel_fn(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer('mel_basis', torch.from_numpy(filterbank).float())

    def spectral_normalize(self, magnitudes):
        """Apply dynamic_range_compression with its default settings."""
        return dynamic_range_compression(magnitudes)

    def spectral_de_normalize(self, magnitudes):
        """Apply dynamic_range_decompression with its default settings."""
        return dynamic_range_decompression(magnitudes)

    def mel_spectrogram(self, y):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        # Sanity bounds on the raw input; values are then clipped to [-1, 1].
        assert torch.min(y.data) >= -10
        assert torch.max(y.data) <= 10
        clipped = torch.clip(y, min=-1, max=1)

        magnitudes, _phases = self.stft_fn.transform(clipped)
        mel = torch.matmul(self.mel_basis, magnitudes.data)
        return self.spectral_normalize(mel)
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians, given their means and
    log-variances.

    Arguments broadcast against each other, so a batch can be compared with
    a scalar. At least one argument has to be a Tensor.
    """
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2) if isinstance(arg, th.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    # th.exp() needs Tensor inputs, so scalar log-variances are promoted to
    # Tensors matching the first Tensor argument found above.
    if not isinstance(logvar1, th.Tensor):
        logvar1 = th.tensor(logvar1).to(reference)
    if not isinstance(logvar2, th.Tensor):
        logvar2 = th.tensor(logvar2).to(reference)

    variance_ratio = th.exp(logvar1 - logvar2)
    squared_gap = (mean1 - mean2) ** 2
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + variance_ratio
        + squared_gap * th.exp(-logvar2)
    )
def mean_flat(tensor):
    """
    Average a tensor over every dimension except the first (batch) one.
    """
    non_batch_axes = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_axes)
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Build a beta schedule by discretizing an alpha_t_bar function, i.e. the
    cumulative product of (1-beta) over time for t in [0,1].

    :param num_diffusion_timesteps: how many betas to produce.
    :param alpha_bar: a callable taking t from 0 to 1 and returning the
                      cumulative product of (1-beta) up to that point of
                      the diffusion process.
    :param max_beta: upper bound applied to every beta; keep it below 1 to
                     prevent singularities.
    :return: a numpy array holding one beta per timestep.
    """
    steps = num_diffusion_timesteps
    schedule = [
        # beta_k = 1 - alpha_bar(end of step) / alpha_bar(start of step), capped.
        min(1 - alpha_bar((k + 1) / steps) / alpha_bar(k / steps), max_beta)
        for k in range(steps)
    ]
    return np.array(schedule)
class LossType(enum.Enum):
    """
    Which loss function is used for training.
    """

    # Raw MSE loss (plus KL when variances are learned).
    MSE = 'mse'
    # Raw MSE loss (plus RESCALED_KL when variances are learned).
    RESCALED_MSE = 'rescaled_mse'
    # The variational lower-bound.
    KL = 'kl'
    # Like KL, but rescaled to estimate the full VLB.
    RESCALED_KL = 'rescaled_kl'

    def is_vb(self):
        """Return True for the two variational-bound losses (KL, RESCALED_KL)."""
        return self in (LossType.KL, LossType.RESCALED_KL)
+ betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
def q_sample(self, x_start, t, noise=None):
    """
    Diffuse the data for a given number of diffusion steps, i.e. draw a
    sample from q(x_t | x_0).

    :param x_start: the initial data batch.
    :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
    :param noise: if specified, the normal noise to mix in; it must have the
                  same shape as x_start. Fresh noise is drawn otherwise.
    :return: A noisy version of x_start.
    """
    eps = th.randn_like(x_start) if noise is None else noise
    assert eps.shape == x_start.shape

    # Per-timestep coefficients looked up from the precomputed schedules.
    signal_scale = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
    noise_scale = _extract_into_tensor(
        self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )
    return signal_scale * x_start + noise_scale * eps
+ + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, self._scale_timesteps(t), **model_kwargs) + if self.conditioning_free: + model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs) + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + if self.conditioning_free: + model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = th.exp(model_log_variance) + else: + min_log = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x.shape + ) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. 
+ frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + if self.conditioning_free: + if self.ramp_conditioning_free: + assert t.shape[0] == 1 # This should only be used in inference. + cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps) + else: + cfk = self.conditioning_free_k + model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart( + self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) + ) + model_mean = model_output + elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t + ) + else: + raise NotImplementedError(self.model_mean_type) + + assert ( + model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + ) + return { + "mean": model_mean, + "variance": model_variance, + 
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """
    Compute the mean for the previous step, given a function cond_fn that
    computes the gradient of a conditional log probability with respect to
    x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
    condition on y.

    This uses the conditioning strategy from Sohl-Dickstein et al. (2015).

    :param cond_fn: callable invoked as cond_fn(x, scaled_t, **model_kwargs).
    :param p_mean_var: dict providing at least the 'mean' and 'variance'
                       tensors of the unconditioned step.
    :param x: the tensor passed through to cond_fn.
    :param t: timesteps; scaled with _scale_timesteps before reaching cond_fn.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
                         pass to cond_fn.
    :return: the shifted mean, mean + variance * gradient.
    """
    # The default of None cannot be unpacked with **, so treat it as "no extras".
    if model_kwargs is None:
        model_kwargs = {}
    gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
    new_mean = (
        p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
    )
    return new_mean
def p_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
):
    """
    Sample x_{t-1} from the model at the given timestep.

    :param model: the model to sample from.
    :param x: the current tensor x_t.
    :param t: the value of t, starting at 0 for the first diffusion step.
    :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
        x_start prediction before it is used to sample.
    :param cond_fn: if not None, this is a gradient function that acts
                    similarly to the model.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :return: a dict containing the following keys:
             - 'sample': a random sample from the model.
             - 'pred_xstart': a prediction of x_0.
    """
    stats = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    eps = th.randn_like(x)

    # (N, 1, ..., 1) mask that switches the noise off for entries with t == 0.
    trailing_ones = [1] * (x.dim() - 1)
    keep_noise = (t != 0).float().view(-1, *trailing_ones)

    if cond_fn is not None:
        stats["mean"] = self.condition_mean(
            cond_fn, stats, x, t, model_kwargs=model_kwargs
        )

    std = th.exp(0.5 * stats["log_variance"])
    return {
        "sample": stats["mean"] + keep_noise * std * eps,
        "pred_xstart": stats["pred_xstart"],
    }
def p_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
):
    """
    Generate samples from the model, yielding the intermediate sample of
    every diffusion timestep.

    Arguments are the same as p_sample_loop().
    Returns a generator over dicts, where each dict is the return value of
    p_sample().
    """
    if device is None:
        # Fall back to wherever the model's first parameter lives.
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))

    current = noise if noise is not None else th.randn(*shape, device=device)
    batch = shape[0]
    # Walk the chain from the last timestep down to 0.
    steps = list(reversed(range(self.num_timesteps)))

    for step in tqdm(steps, disable=not progress):
        t = th.tensor([step] * batch, device=device)
        with th.no_grad():
            out = self.p_sample(
                model,
                current,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs=model_kwargs,
            )
            yield out
            current = out["sample"]
def ddim_reverse_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    model_kwargs=None,
    eta=0.0,
):
    """
    Sample x_{t+1} from the model using DDIM reverse ODE.

    :param eta: must be 0.0; only the deterministic path is supported.
    :return: a dict with 'sample' (the x_{t+1} estimate) and 'pred_xstart'.
    """
    assert eta == 0.0, "Reverse ODE only for deterministic path"
    stats = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    x0_pred = stats["pred_xstart"]

    # Usually our model outputs epsilon, but we re-derive it
    # in case we used x_start or x_prev prediction.
    recip = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape)
    recipm1 = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
    eps = (recip * x - x0_pred) / recipm1

    alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

    # Equation 12. reversed
    mean_pred = (
        x0_pred * th.sqrt(alpha_bar_next)
        + th.sqrt(1 - alpha_bar_next) * eps
    )
    return {"sample": mean_pred, "pred_xstart": x0_pred}
def ddim_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    eta=0.0,
):
    """
    Sample with DDIM step by step, yielding every intermediate result.

    Same usage as p_sample_loop_progressive(), plus ``eta`` which is
    forwarded to ddim_sample(). Timesteps are visited from
    ``self.num_timesteps - 1`` down to 0 and the dict returned by
    ddim_sample() is yielded for each of them.
    """
    # Fall back to the device the model's parameters live on.
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))

    img = noise if noise is not None else th.randn(*shape, device=device)
    timesteps = list(range(self.num_timesteps))[::-1]

    if progress:
        # Lazy import so that we don't depend on tqdm.
        from tqdm.auto import tqdm

        timesteps = tqdm(timesteps, disable=not progress)

    for step in timesteps:
        # One timestep index per batch element.
        t = th.tensor([step] * shape[0], device=device)
        with th.no_grad():
            out = self.ddim_sample(
                model,
                img,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs=model_kwargs,
                eta=eta,
            )
            yield out
            # Feed the sample just produced into the next step.
            img = out["sample"]
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
    """
    Compute training losses for a single timestep.

    :param model: the model to evaluate loss on.
    :param x_start: the [N x C x ...] tensor of inputs.
    :param t: a batch of timestep indices.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :param noise: if specified, the specific Gaussian noise to try to remove.
    :return: a dict with the key "loss" containing a tensor of shape [N].
             For the MSE loss types it also holds "mse" and
             "x_start_predicted", "vb" when the variance is learned, and
             "extra_outputs" when the model returns a tuple.
    :raises NotImplementedError: for an unknown loss type or mean type.
    """
    if model_kwargs is None:
        model_kwargs = {}
    if noise is None:
        noise = th.randn_like(x_start)
    x_t = self.q_sample(x_start, t, noise=noise)

    terms = {}

    if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
        # TODO: support multiple model outputs for this mode.
        terms["loss"] = self._vb_terms_bpd(
            model=model,
            x_start=x_start,
            x_t=x_t,
            t=t,
            clip_denoised=False,
            model_kwargs=model_kwargs,
        )["output"]
        if self.loss_type == LossType.RESCALED_KL:
            terms["loss"] *= self.num_timesteps
    elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
        model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
        # A tuple means the first element is the diffusion output and the
        # rest are auxiliary outputs handed back to the caller untouched.
        if isinstance(model_outputs, tuple):
            model_output = model_outputs[0]
            terms['extra_outputs'] = model_outputs[1:]
        else:
            model_output = model_outputs

        if self.model_var_type in [
            ModelVarType.LEARNED,
            ModelVarType.LEARNED_RANGE,
        ]:
            B, C = x_t.shape[:2]
            assert model_output.shape == (B, C * 2, *x_t.shape[2:])
            model_output, model_var_values = th.split(model_output, C, dim=1)
            # Learn the variance using the variational bound, but don't let
            # it affect our mean prediction.
            frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
            terms["vb"] = self._vb_terms_bpd(
                # Stand-in "model" that returns the already computed output.
                model=lambda *args, r=frozen_out: r,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
            )["output"]
            if self.loss_type == LossType.RESCALED_MSE:
                # Divide by 1000 for equivalence with initial implementation.
                # Without a factor of 1/1000, the VB term hurts the MSE term.
                terms["vb"] *= self.num_timesteps / 1000.0

        if self.model_mean_type == ModelMeanType.PREVIOUS_X:
            target = self.q_posterior_mean_variance(
                x_start=x_start, x_t=x_t, t=t
            )[0]
            # Recovering x_0 is not supported for this mean type, so an
            # all-zero placeholder with x_start's shape/dtype/device is
            # reported. (The previous torch.zeros(x_start) passed a tensor
            # where a size is expected and raised.)
            x_start_pred = th.zeros_like(x_start)
        elif self.model_mean_type == ModelMeanType.START_X:
            target = x_start
            x_start_pred = model_output
        elif self.model_mean_type == ModelMeanType.EPSILON:
            target = noise
            x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output)
        else:
            raise NotImplementedError(self.model_mean_type)
        assert model_output.shape == target.shape == x_start.shape
        terms["mse"] = mean_flat((target - model_output) ** 2)
        terms["x_start_predicted"] = x_start_pred
        if "vb" in terms:
            terms["loss"] = terms["mse"] + terms["vb"]
        else:
            terms["loss"] = terms["mse"]
    else:
        raise NotImplementedError(self.loss_type)

    return terms
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs) + terms.update({k: o for k, o in zip(model_output_keys, model_outputs)}) + model_output = terms[gd_out_key] + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C, 2, *x_t.shape[2:]) + model_output, model_var_values = model_output[:, :, 0], model_output[:, :, 1] + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + target = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0] + x_start_pred = torch.zeros(x_start) # Not supported. 
def _prior_bpd(self, x_start):
    """
    Get the prior KL term for the variational lower-bound, measured in
    bits-per-dim.

    This term can't be optimized, as it only depends on the encoder.

    :param x_start: the [N x C x ...] tensor of inputs.
    :return: a batch of [N] KL values (in bits), one per batch element.
    """
    n_items = x_start.shape[0]
    # Every batch element is evaluated at the final diffusion step.
    last_step = th.tensor(
        [self.num_timesteps - 1] * n_items, device=x_start.device
    )
    q_mean, _, q_log_var = self.q_mean_variance(x_start, last_step)
    # KL between q(x_T | x_0) and a normal with mean 0 and log-variance 0.
    kl = normal_kl(mean1=q_mean, logvar1=q_log_var, mean2=0.0, logvar2=0.0)
    # Dividing by ln(2) converts nats to bits.
    return mean_flat(kl) / np.log(2.0)
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Return the pre-defined beta schedule registered under ``schedule_name``.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.

    Known names are "linear" and "cosine"; any other name raises
    NotImplementedError.
    """
    if schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    if schedule_name != "linear":
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

    # Linear schedule from Ho et al, extended to work for any number of
    # diffusion steps: the end points are scaled by 1000 / num_steps.
    scale = 1000 / num_diffusion_timesteps
    first_beta = scale * 0.0001
    last_beta = scale * 0.02
    return np.linspace(
        first_beta, last_beta, num_diffusion_timesteps, dtype=np.float64
    )
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    :raises ValueError: if no integer stride yields exactly N steps for
                        "ddimN", or if a section has fewer steps than the
                        count requested from it.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            # Try every integer stride until one produces exactly the
            # requested number of steps.
            for stride in range(1, num_timesteps):
                candidate = range(0, num_timesteps, stride)
                if len(candidate) == desired_count:
                    return set(candidate)
            # Report the count that was asked for (the message previously
            # printed the total number of timesteps, which was misleading).
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]

    # Split the process into len(section_counts) nearly equal sections; the
    # first `extra` sections each absorb one leftover step.
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        # Fractional stride so the first and last step of the section are
        # both taken when more than one step is requested.
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
dtype=ts.dtype) + new_ts = map_tensor[ts] + if self.rescale_timesteps: + new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, x0, new_ts, **kwargs) + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr.astype(np.float32)).to(device=timesteps.device)[timesteps] + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) \ No newline at end of file diff --git a/tortoise/utils/stft.py b/tortoise/utils/stft.py new file mode 100644 index 0000000000000000000000000000000000000000..f54eb968225cfe5928cca6d7686abbcc3728a674 --- /dev/null +++ b/tortoise/utils/stft.py @@ -0,0 +1,193 @@ +""" +BSD 3-Clause License + +Copyright (c) 2017, Prem Seetharaman +All rights reserved. + +* Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from this + software without specific prior written permission. 
def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
                     n_fft=800, dtype=np.float32, norm=None):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`

    n_frames : int > 0
        The number of analysis frames

    hop_length : int > 0
        The number of samples to advance between frames

    win_length : [optional]
        The length of the window function. If None, `n_fft` is used.

    n_fft : int > 0
        The length of each analysis frame.

    dtype : np.dtype
        The data type of the output

    norm : [optional]
        Forwarded to `librosa.util.normalize` before the window is squared.

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm)**2
    # `size` is passed by keyword, matching the pad_center call in
    # STFT.__init__; newer librosa releases reject it as a positional.
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)

    # Fill the envelope: overlap-add the squared window at every hop,
    # truncating the last frames so they never run past the output.
    for i in range(n_frames):
        sample = i * hop_length
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x
inverse_basis.float()) + + def transform(self, input_data): + num_batches = input_data.size(0) + num_samples = input_data.size(1) + + self.num_samples = num_samples + + # similar to librosa, reflect-pad the input + input_data = input_data.view(num_batches, 1, num_samples) + input_data = F.pad( + input_data.unsqueeze(1), + (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), + mode='reflect') + input_data = input_data.squeeze(1) + + forward_transform = F.conv1d( + input_data, + Variable(self.forward_basis, requires_grad=False), + stride=self.hop_length, + padding=0) + + cutoff = int((self.filter_length / 2) + 1) + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + + magnitude = torch.sqrt(real_part**2 + imag_part**2) + phase = torch.autograd.Variable( + torch.atan2(imag_part.data, real_part.data)) + + return magnitude, phase + + def inverse(self, magnitude, phase): + recombine_magnitude_phase = torch.cat( + [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) + + inverse_transform = F.conv_transpose1d( + recombine_magnitude_phase, + Variable(self.inverse_basis, requires_grad=False), + stride=self.hop_length, + padding=0) + + if self.window is not None: + window_sum = window_sumsquare( + self.window, magnitude.size(-1), hop_length=self.hop_length, + win_length=self.win_length, n_fft=self.filter_length, + dtype=np.float32) + # remove modulation effects + approx_nonzero_indices = torch.from_numpy( + np.where(window_sum > tiny(window_sum))[0]) + window_sum = torch.autograd.Variable( + torch.from_numpy(window_sum), requires_grad=False) + window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] + + # scale by hop ratio + inverse_transform *= float(self.filter_length) / self.hop_length + + inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] + inverse_transform = inverse_transform[:, :, 
def split_and_recombine_text(text, desired_length=200, max_length=300):
    """Split text into chunks of a desired length, trying to keep sentences intact.

    Whitespace is collapsed and curly double quotes are converted to ASCII
    first. Chunks are closed at sentence boundaries once they reach
    `desired_length`, and a split is forced when a chunk reaches
    `max_length`. Returns the list of stripped chunks, without chunks that
    are empty or consist only of whitespace/punctuation.
    """
    # normalize text, remove redundant whitespace and convert non-ascii quotes to ascii
    text = re.sub(r'\n\n+', '\n', text)
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[“”]', '"', text)

    chunks = []
    in_quote = False
    buffer = ""
    boundaries = []
    cursor = -1
    last_index = len(text) - 1

    def seek(delta):
        # Move the cursor by `delta` characters, keeping the buffer and the
        # quote state in step, and return the character under the cursor.
        nonlocal cursor, in_quote, buffer
        step = -1 if delta < 0 else 1
        for _ in range(abs(delta)):
            cursor += step
            if step < 0:
                buffer = buffer[:-1]
            else:
                buffer += text[cursor]
            if text[cursor] == '"':
                in_quote = not in_quote
        return text[cursor]

    def peek(delta):
        # Look at a nearby character without moving; "" when out of range.
        target = cursor + delta
        if target < last_index and target >= 0:
            return text[target]
        return ""

    def commit():
        # Close the current chunk and forget its recorded boundaries.
        nonlocal chunks, buffer, boundaries
        chunks.append(buffer)
        buffer = ""
        boundaries = []

    while cursor < last_index:
        c = seek(1)
        # do we need to force a split?
        if len(buffer) >= max_length:
            if len(boundaries) > 0 and len(buffer) > (desired_length / 2):
                # we have at least one sentence and we are over half the desired length, seek back to the last split
                seek(-(cursor - boundaries[-1]))
            else:
                # no full sentences, seek back until we are not in the middle of a word and split there
                while c not in '!?.\n ' and cursor > 0 and len(buffer) > desired_length:
                    c = seek(-1)
            commit()
        # check for sentence boundaries
        elif not in_quote and (c in '!?\n' or (c == '.' and peek(1) in '\n ')):
            # seek forward if we have consecutive boundary markers but still within the max length
            while cursor < len(text) - 1 and len(buffer) < max_length and peek(1) in '!?.':
                c = seek(1)
            boundaries.append(cursor)
            if len(buffer) >= desired_length:
                commit()
        # treat end of quote as a boundary if its followed by a space or newline
        elif in_quote and peek(1) == '"' and peek(2) in '\n ':
            seek(2)
            boundaries.append(cursor)
    chunks.append(buffer)

    # clean up, remove lines with only whitespace or punctuation
    stripped = (chunk.strip() for chunk in chunks)
    return [s for s in stripped if len(s) > 0 and not re.match(r'^[\s\.,;:!?]*$', s)]
Is this a good thing to do?!?!?! + I don't know but we should handle this situation.......................... + """ + self.assertEqual(split_and_recombine_text(text, desired_length=30, max_length=50), + ['When you are really angry sometimes you use', + 'consecutive exclamation marks!!!!!!', + 'Is this a good thing to do?!?!?!', + 'I don\'t know but we should handle this situation.']) + + def test_split_and_recombine_text_3(self): + text_src = os.path.join(os.path.dirname(__file__), '../data/riding_hood.txt') + with open(text_src, 'r') as f: + text = f.read() + self.assertEqual( + split_and_recombine_text(text), + [ + 'Once upon a time there lived in a certain village a little country girl, the prettiest creature who was ever seen. Her mother was excessively fond of her; and her grandmother doted on her still more. This good woman had a little red riding hood made for her.', + 'It suited the girl so extremely well that everybody called her Little Red Riding Hood. One day her mother, having made some cakes, said to her, "Go, my dear, and see how your grandmother is doing, for I hear she has been very ill. Take her a cake, and this little pot of butter."', + 'Little Red Riding Hood set out immediately to go to her grandmother, who lived in another village. As she was going through the wood, she met with a wolf, who had a very great mind to eat her up, but he dared not, because of some woodcutters working nearby in the forest.', + 'He asked her where she was going. The poor child, who did not know that it was dangerous to stay and talk to a wolf, said to him, "I am going to see my grandmother and carry her a cake and a little pot of butter from my mother." "Does she live far off?" said the wolf "Oh I say,"', + 'answered Little Red Riding Hood; "it is beyond that mill you see there, at the first house in the village." "Well," said the wolf, "and I\'ll go and see her too. 
I\'ll go this way and go you that, and we shall see who will be there first."', + 'The wolf ran as fast as he could, taking the shortest path, and the little girl took a roundabout way, entertaining herself by gathering nuts, running after butterflies, and gathering bouquets of little flowers.', + 'It was not long before the wolf arrived at the old woman\'s house. He knocked at the door: tap, tap. "Who\'s there?" "Your grandchild, Little Red Riding Hood," replied the wolf, counterfeiting her voice; "who has brought you a cake and a little pot of butter sent you by mother."', + 'The good grandmother, who was in bed, because she was somewhat ill, cried out, "Pull the bobbin, and the latch will go up."', + 'The wolf pulled the bobbin, and the door opened, and then he immediately fell upon the good woman and ate her up in a moment, for it been more than three days since he had eaten.', + 'He then shut the door and got into the grandmother\'s bed, expecting Little Red Riding Hood, who came some time afterwards and knocked at the door: tap, tap. "Who\'s there?"', + 'Little Red Riding Hood, hearing the big voice of the wolf, was at first afraid; but believing her grandmother had a cold and was hoarse, answered, "It is your grandchild Little Red Riding Hood, who has brought you a cake and a little pot of butter mother sends you."', + 'The wolf cried out to her, softening his voice as much as he could, "Pull the bobbin, and the latch will go up." Little Red Riding Hood pulled the bobbin, and the door opened.', + 'The wolf, seeing her come in, said to her, hiding himself under the bedclothes, "Put the cake and the little pot of butter upon the stool, and come get into bed with me." Little Red Riding Hood took off her clothes and got into bed.', + 'She was greatly amazed to see how her grandmother looked in her nightclothes, and said to her, "Grandmother, what big arms you have!" "All the better to hug you with, my dear." "Grandmother, what big legs you have!" 
"All the better to run with, my child." "Grandmother, what big ears you have!"', + '"All the better to hear with, my child." "Grandmother, what big eyes you have!" "All the better to see with, my child." "Grandmother, what big teeth you have got!" "All the better to eat you up with." And, saying these words, this wicked wolf fell upon Little Red Riding Hood, and ate her all up.', + ] + ) + + unittest.main() diff --git a/tortoise/utils/tokenizer.py b/tortoise/utils/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..922f23ec2ae383abb653ca01d5b6c008a3b0b5fe --- /dev/null +++ b/tortoise/utils/tokenizer.py @@ -0,0 +1,194 @@ +import os +import re + +import inflect +import torch +from tokenizers import Tokenizer + + +# Regular expression matching whitespace: +from unidecode import unidecode + +_whitespace_re = re.compile(r'\s+') + + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ + ('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), +]] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +_inflect = inflect.engine() +_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') +_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') +_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') +_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') +_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') +_number_re = re.compile(r'[0-9]+') + + +def _remove_commas(m): + return m.group(1).replace(',', '') + + +def _expand_decimal_point(m): + return 
def _expand_decimal_point(m):
    """Replace the decimal point of a matched number with the word 'point'."""
    return m.group(1).replace('.', ' point ')


def _expand_dollars(m):
    """Spell out a matched dollar amount as dollars and/or cents.

    ``m.group(1)`` is the text captured after the ``$`` sign. It is split on
    '.' into a dollar part and an optional cent part; more than one '.' is
    treated as an unexpected format and returned with ' dollars' appended.

    Returns strings such as '1 dollar, 50 cents', '5 dollars', '1 cent' or
    'zero dollars'. The numbers stay as digits here.
    """
    match = m.group(1)
    parts = match.split('.')
    if len(parts) > 2:
        return match + ' dollars'  # Unexpected format
    # The capturing pattern allows commas, and a comma can still be present
    # here (e.g. "$,5" or "$1,.5"); drop it so int() does not raise ValueError.
    dollar_digits = parts[0].replace(',', '')
    cent_digits = parts[1].replace(',', '') if len(parts) > 1 else ''
    dollars = int(dollar_digits) if dollar_digits else 0
    cents = int(cent_digits) if cent_digits else 0
    if dollars and cents:
        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
        cent_unit = 'cent' if cents == 1 else 'cents'
        return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
    elif dollars:
        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
        return '%s %s' % (dollars, dollar_unit)
    elif cents:
        cent_unit = 'cent' if cents == 1 else 'cents'
        return '%s %s' % (cents, cent_unit)
    else:
        return 'zero dollars'


def _expand_ordinal(m):
    """Convert a matched ordinal such as '3rd' to words via the inflect engine."""
    return _inflect.number_to_words(m.group(0))


def _expand_number(m):
    """Convert a matched run of digits to words.

    Values strictly between 1000 and 3000 get special handling: 2000 is
    'two thousand', 2001-2009 are 'two thousand <n>', exact hundreds are
    '<n> hundred', and everything else is read in two-digit groups with
    'oh' for zero. All other values use the plain inflect conversion
    without 'and'.
    """
    num = int(m.group(0))
    if num > 1000 and num < 3000:
        if num == 2000:
            return 'two thousand'
        elif num > 2000 and num < 2010:
            return 'two thousand ' + _inflect.number_to_words(num % 100)
        elif num % 100 == 0:
            return _inflect.number_to_words(num // 100) + ' hundred'
        else:
            return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
    else:
        return _inflect.number_to_words(num, andword='')


def normalize_numbers(text):
    """Rewrite numeric expressions in ``text`` as words.

    The substitutions run in a fixed order: comma-grouped numbers lose their
    commas, then pounds, dollars, decimals and ordinals are expanded, and
    finally any remaining digit runs are converted.
    """
    text = re.sub(_comma_number_re, _remove_commas, text)
    text = re.sub(_pounds_re, r'\1 pounds', text)
    text = re.sub(_dollars_re, _expand_dollars, text)
    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
    text = re.sub(_ordinal_re, _expand_ordinal, text)
    text = re.sub(_number_re, _expand_number, text)
    return text


def expand_numbers(text):
    """Alias for :func:`normalize_numbers`."""
    return normalize_numbers(text)


def lowercase(text):
    """Return ``text`` lower-cased."""
    return text.lower()


def collapse_whitespace(text):
    """Replace every run of whitespace in ``text`` with a single space."""
    return re.sub(_whitespace_re, ' ', text)


def convert_to_ascii(text):
    """Transliterate ``text`` to ASCII using ``unidecode``."""
    return unidecode(text)


def basic_cleaners(text):
    '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def transliteration_cleaners(text):
    '''Pipeline for non-English text that transliterate to ASCII.'''
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def english_cleaners(text):
    '''Pipeline for English text, including number and abbreviation expansion.'''
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = collapse_whitespace(text)
    # Double quotes are dropped entirely from the cleaned text.
    text = text.replace('"', '')
    return text


def lev_distance(s1, s2):
    """Return the Levenshtein (edit) distance between two sequences.

    Only one row of the dynamic-programming table is kept at a time; the
    shorter input is used as the row to keep that row small.
    """
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]


# Tokenizer vocabulary shipped next to this module, used when no file is given.
DEFAULT_VOCAB_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data/tokenizer.json')


class VoiceBpeTokenizer:
    """Wrap a ``tokenizers.Tokenizer`` with text cleaning and '[SPACE]' handling."""

    def __init__(self, vocab_file=None, use_basic_cleaners=False):
        """Load the tokenizer and choose the text-cleaning pipeline.

        Args:
            vocab_file: Path to a tokenizer file; ``DEFAULT_VOCAB_FILE`` is
                used when this is None.
            use_basic_cleaners: Use ``basic_cleaners`` instead of
                ``english_cleaners`` for preprocessing.
        """
        self.tokenizer = Tokenizer.from_file(
            DEFAULT_VOCAB_FILE if vocab_file is None else vocab_file
        )
        if use_basic_cleaners:
            self.preprocess_text = basic_cleaners
        else:
            self.preprocess_text = english_cleaners

    def encode(self, txt):
        """Clean ``txt``, mark spaces as '[SPACE]' and return the token ids."""
        txt = self.preprocess_text(txt)
        txt = txt.replace(' ', '[SPACE]')
        return self.tokenizer.encode(txt).ids

    def decode(self, seq):
        """Turn token ids back into text.

        Tensors are moved to the CPU and converted to numpy first. Spaces in
        the decoded output are removed, then '[SPACE]' becomes a real space
        and the '[STOP]' and '[UNK]' markers are removed.
        """
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()
        txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(' ', '')
        txt = txt.replace('[SPACE]', ' ')
        txt = txt.replace('[STOP]', '')
        txt = txt.replace('[UNK]', '')
        return txt
class TypicalLogitsWarper(LogitsWarper):
    """Logits warper that filters tokens by how close they are to the entropy.

    Tokens are ranked by the absolute difference between their negative
    log-probability and the entropy of the distribution. The closest tokens
    are kept until their cumulative probability reaches ``mass``; the rest
    are set to ``filter_value``.
    """

    def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        """Store the sampling parameters.

        Args:
            mass: Cumulative probability to keep.
            filter_value: Score written into filtered positions.
            min_tokens_to_keep: Number of best-ranked tokens that are never
                filtered (only enforced when greater than 1).
        """
        self.filter_value = filter_value
        self.mass = mass
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` with the filtered tokens set to ``filter_value``.

        ``input_ids`` is not used. ``scores`` is indexed along dim 1, so it is
        treated as a 2-D tensor.
        """
        # calculate entropy
        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
        p = torch.exp(normalized)
        ent = -(normalized * p).nansum(-1, keepdim=True)

        # shift and sort
        shifted_scores = torch.abs((-normalized) - ent)
        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
        sorted_logits = scores.gather(-1, sorted_indices)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Remove tokens with cumulative mass above the threshold
        last_ind = (cumulative_probs < self.mass).sum(dim=1)
        # When every cumulative probability is below `mass` (e.g. mass >= 1.0),
        # the count equals the vocabulary size, which is one past the last valid
        # index for gather(); clamp it into range.
        last_ind.clamp_(min=0, max=sorted_scores.shape[-1] - 1)
        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
        # Map the removal mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
def max_alignment(s1, s2, skip_character='~', record=None):
    """
    Aligns s1 to s2 as best it can. Wherever a character from s1 is not found in s2, a '~' is
    used to replace that character.

    The result always has one output position per character of s1: either the
    character itself (it was matched, in order, against s2) or
    ``skip_character``. The number of kept characters is maximal.

    This is computed iteratively with a suffix table, so long inputs do not
    run into Python's recursion limit.

    Args:
        s1: Text to align; must not contain ``skip_character``.
        s2: Text to align against.
        skip_character: Marker written for unmatched characters of s1
            (expected to be a single character).
        record: Unused; accepted only so existing callers keep working.

    Raises:
        AssertionError: If ``skip_character`` occurs in ``s1``.
    """
    assert skip_character not in s1, f"Found the skip character {skip_character} in the provided string, {s1}"
    n, m = len(s1), len(s2)
    if n == 0:
        return ''
    if m == 0:
        return skip_character * n
    if s1 == s2:
        return s1

    # kept[i][j] = maximum number of characters of s1[i:] that can be matched,
    # in order, against s2[j:].
    kept = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n - 1, -1, -1):
        row, below = kept[i], kept[i + 1]
        c1 = s1[i]
        for j in range(m - 1, -1, -1):
            if c1 == s2[j]:
                row[j] = below[j + 1] + 1
            else:
                advance_s2, skip_s1 = row[j + 1], below[j]
                row[j] = advance_s2 if advance_s2 > skip_s1 else skip_s1

    # Walk the table from the start. On a tie the character of s1 is skipped,
    # which is the same preference the pairwise comparison always had.
    out = []
    i = j = 0
    while i < n and j < m:
        if s1[i] == s2[j]:
            out.append(s1[i])
            i += 1
            j += 1
        elif kept[i][j + 1] > kept[i + 1][j]:
            j += 1
        else:
            out.append(skip_character)
            i += 1
    # s2 is exhausted: everything left in s1 is unmatched.
    out.append(skip_character * (n - i))
    return ''.join(out)


class Wav2VecAlignment:
    """
    Uses wav2vec2 to perform audio<->text alignment.
    """
    def __init__(self, device='cuda' if not torch.backends.mps.is_available() else 'mps'):
        """Load the pretrained model, feature extractor and tokenizer.

        The model is kept on the CPU until `align` runs. `device` is where
        inference runs; the default is 'mps' when available, otherwise 'cuda'.
        """
        self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
        self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')
        self.device = device

    def align(self, audio, expected_text, audio_sample_rate=24000):
        """Return one sample offset into `audio` per character of `expected_text`.

        The audio is resampled to 16 kHz, normalized and run through the model.
        The lower-cased expected text is fitted to the model's prediction with
        `max_alignment`, and the per-frame predictions are scanned for each
        expected token in order. Characters that could not be matched are
        given offsets interpolated between their matched neighbours.

        Raises:
            AssertionError: If the scan does not consume every expected token
                or the number of offsets differs from `len(expected_text)`.
                The inputs are saved to 'alignment_debug.pth' first.
        """
        orig_len = audio.shape[-1]

        with torch.no_grad():
            # The model is moved to the inference device only for this call.
            self.model = self.model.to(self.device)
            audio = audio.to(self.device)
            audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
            clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
            logits = self.model(clip_norm).logits
            self.model = self.model.cpu()

        logits = logits[0]
        pred_string = self.tokenizer.decode(logits.argmax(-1).tolist())

        fixed_expectation = max_alignment(expected_text.lower(), pred_string)
        # Number of original audio samples covered by one logit frame.
        w2v_compression = orig_len // logits.shape[0]
        expected_tokens = self.tokenizer.encode(fixed_expectation)
        expected_chars = list(fixed_expectation)
        if len(expected_tokens) == 1:
            return [0]  # The alignment is simple; there is only one token.
        expected_tokens.pop(0)  # The first token is a given.
        expected_chars.pop(0)

        alignments = [0]
        def pop_till_you_win():
            # Pop the next expected token, recording -1 for every skipped ('~')
            # character on the way; returns None once the tokens run out.
            if len(expected_tokens) == 0:
                return None
            popped = expected_tokens.pop(0)
            popped_char = expected_chars.pop(0)
            while popped_char == '~':
                alignments.append(-1)
                if len(expected_tokens) == 0:
                    return None
                popped = expected_tokens.pop(0)
                popped_char = expected_chars.pop(0)
            return popped

        next_expected_token = pop_till_you_win()
        for i, logit in enumerate(logits):
            top = logit.argmax()
            if next_expected_token == top:
                alignments.append(i * w2v_compression)
                if len(expected_tokens) > 0:
                    next_expected_token = pop_till_you_win()
                else:
                    break

        pop_till_you_win()
        if not (len(expected_tokens) == 0 and len(alignments) == len(expected_text)):
            torch.save([audio, expected_text], 'alignment_debug.pth')
            # Raised explicitly so the check is not stripped under `python -O`.
            raise AssertionError(
                "Something went wrong with the alignment algorithm. I've dumped a file, 'alignment_debug.pth' to "
                "your current working directory. Please report this along with the file so it can get fixed."
            )

        # Now fix up alignments. Anything with -1 should be interpolated.
        alignments.append(orig_len)  # This'll get removed but makes the algorithm below more readable.
        for i in range(len(alignments)):
            if alignments[i] == -1:
                # The sentinel appended above guarantees a non -1 entry is found.
                for j in range(i+1, len(alignments)):
                    if alignments[j] != -1:
                        next_found_token = j
                        break
                for j in range(i, next_found_token):
                    gap = alignments[next_found_token] - alignments[i-1]
                    alignments[j] = (j-i+1) * gap // (next_found_token-i+1) + alignments[i-1]

        return alignments[:-1]

    def redact(self, audio, expected_text, audio_sample_rate=24000):
        """Return `audio` with the parts spoken for bracketed text cut out.

        Text inside '[...]' in `expected_text` marks what to remove. The text
        without brackets is aligned with `align`, and the audio for each
        non-bracketed span is concatenated along the last dimension. If there
        is no '[' the audio is returned unchanged.

        Raises:
            AssertionError: If a '[' has no matching ']'.
        """
        if '[' not in expected_text:
            return audio
        splitted = expected_text.split('[')
        fully_split = [splitted[0]]
        for spl in splitted[1:]:
            assert ']' in spl, 'Every "[" character must be paired with a "]" with no nesting.'
            fully_split.extend(spl.split(']'))

        # At this point, fully_split is a list of strings, with every other string being something that should be redacted.
        non_redacted_intervals = []
        last_point = 0
        for i in range(len(fully_split)):
            if i % 2 == 0 and fully_split[i] != "":  # Check for empty string fixes index error
                end_interval = max(0, last_point + len(fully_split[i]) - 1)
                non_redacted_intervals.append((last_point, end_interval))
            # Redacted pieces advance the position too: bare_text contains them.
            last_point += len(fully_split[i])

        bare_text = ''.join(fully_split)
        alignments = self.align(audio, bare_text, audio_sample_rate)

        output_audio = []
        for nri in non_redacted_intervals:
            start, stop = nri
            output_audio.append(audio[:, alignments[start]:alignments[stop]])
        return torch.cat(output_audio, dim=-1)
0000000000000000000000000000000000000000..eb0881ef39ff4246ef08f38e805d17ef16295d61 Binary files /dev/null and b/tortoise/voices/daniel/1.wav differ diff --git a/tortoise/voices/daniel/2.wav b/tortoise/voices/daniel/2.wav new file mode 100644 index 0000000000000000000000000000000000000000..677435944afc62d0592a1bf945e6e7d689343c14 Binary files /dev/null and b/tortoise/voices/daniel/2.wav differ diff --git a/tortoise/voices/daniel/3.wav b/tortoise/voices/daniel/3.wav new file mode 100644 index 0000000000000000000000000000000000000000..9c1d7e928ee278396756d5c080c426e6abaaf403 Binary files /dev/null and b/tortoise/voices/daniel/3.wav differ diff --git a/tortoise/voices/daniel/4.wav b/tortoise/voices/daniel/4.wav new file mode 100644 index 0000000000000000000000000000000000000000..bf34df273602c8daaad1245945784f827e3ddea0 Binary files /dev/null and b/tortoise/voices/daniel/4.wav differ diff --git a/tortoise/voices/deniro/1.wav b/tortoise/voices/deniro/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..391a4878e9f76a84b5f917de20fc8bd303071ad5 Binary files /dev/null and b/tortoise/voices/deniro/1.wav differ diff --git a/tortoise/voices/deniro/3.wav b/tortoise/voices/deniro/3.wav new file mode 100644 index 0000000000000000000000000000000000000000..bbb5737b06b5bb75f76d105cadc3907575c5a880 Binary files /dev/null and b/tortoise/voices/deniro/3.wav differ diff --git a/tortoise/voices/deniro/4.wav b/tortoise/voices/deniro/4.wav new file mode 100644 index 0000000000000000000000000000000000000000..9847b6edecbed48746eecfb7ce24f187817525a8 Binary files /dev/null and b/tortoise/voices/deniro/4.wav differ diff --git a/tortoise/voices/emma/1.wav b/tortoise/voices/emma/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..5acab208fc1e9bcc6281fad4e8feb7996ee604a2 Binary files /dev/null and b/tortoise/voices/emma/1.wav differ diff --git a/tortoise/voices/emma/2.wav b/tortoise/voices/emma/2.wav new file mode 100644 index 
0000000000000000000000000000000000000000..ca7bfe92d369763ade070af32f43d487cb5ba674 Binary files /dev/null and b/tortoise/voices/emma/2.wav differ diff --git a/tortoise/voices/emma/3.wav b/tortoise/voices/emma/3.wav new file mode 100644 index 0000000000000000000000000000000000000000..5b065fce7b70073d713e7b40856d1ba3b0f6c919 Binary files /dev/null and b/tortoise/voices/emma/3.wav differ diff --git a/tortoise/voices/freeman/1.wav b/tortoise/voices/freeman/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..0b6941e43e50025d177986f7176452144fb057c1 Binary files /dev/null and b/tortoise/voices/freeman/1.wav differ diff --git a/tortoise/voices/freeman/2.wav b/tortoise/voices/freeman/2.wav new file mode 100644 index 0000000000000000000000000000000000000000..7377fd0852bfe9f35d559c05ab5342e598e25bd2 Binary files /dev/null and b/tortoise/voices/freeman/2.wav differ diff --git a/tortoise/voices/freeman/3.wav b/tortoise/voices/freeman/3.wav new file mode 100644 index 0000000000000000000000000000000000000000..889cee806a08efc2de73caf488b6bc4bf352f497 Binary files /dev/null and b/tortoise/voices/freeman/3.wav differ diff --git a/tortoise/voices/geralt/1.wav b/tortoise/voices/geralt/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..b263cf4908ddb8c1a8b2cde3583304e2a0623fdd Binary files /dev/null and b/tortoise/voices/geralt/1.wav differ diff --git a/tortoise/voices/geralt/2.wav b/tortoise/voices/geralt/2.wav new file mode 100644 index 0000000000000000000000000000000000000000..953459a8c2d8e5139225c63abac101521c3e212a Binary files /dev/null and b/tortoise/voices/geralt/2.wav differ diff --git a/tortoise/voices/geralt/3.wav b/tortoise/voices/geralt/3.wav new file mode 100644 index 0000000000000000000000000000000000000000..5a40836146bb2cab346df20a5ffb90ba9cbd63e0 Binary files /dev/null and b/tortoise/voices/geralt/3.wav differ diff --git a/tortoise/voices/halle/1.wav b/tortoise/voices/halle/1.wav new file mode 100644 index 
0000000000000000000000000000000000000000..a023dab1c784225c7501d8bf3a3c5b67d0e46109 Binary files /dev/null and b/tortoise/voices/halle/1.wav differ diff --git a/tortoise/voices/halle/2.wav b/tortoise/voices/halle/2.wav new file mode 100644 index 0000000000000000000000000000000000000000..07f738a7fe4e37d9e12a48ebbc563eaee65a7012 Binary files /dev/null and b/tortoise/voices/halle/2.wav differ diff --git a/tortoise/voices/halle/3.wav b/tortoise/voices/halle/3.wav new file mode 100644 index 0000000000000000000000000000000000000000..8b7914492d254157f95d0f1d00095abf6b7521c7 Binary files /dev/null and b/tortoise/voices/halle/3.wav differ diff --git a/tortoise/voices/jlaw/1.wav b/tortoise/voices/jlaw/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..e749d0e494cf771b332437b924a4af0b8b92977d Binary files /dev/null and b/tortoise/voices/jlaw/1.wav differ diff --git a/tortoise/voices/jlaw/2.wav b/tortoise/voices/jlaw/2.wav new file mode 100644 index 0000000000000000000000000000000000000000..7dd51de3468e815baf65e343ea4362614120bd11 Binary files /dev/null and b/tortoise/voices/jlaw/2.wav differ diff --git a/tortoise/voices/jlaw/3.wav b/tortoise/voices/jlaw/3.wav new file mode 100644 index 0000000000000000000000000000000000000000..429230f47cba63ffc0282f4af0e1e19962069c25 Binary files /dev/null and b/tortoise/voices/jlaw/3.wav differ diff --git a/tortoise/voices/jlaw/4.wav b/tortoise/voices/jlaw/4.wav new file mode 100644 index 0000000000000000000000000000000000000000..e475993dd2ebfbe01293f8abaffba3482e7edc09 Binary files /dev/null and b/tortoise/voices/jlaw/4.wav differ diff --git a/tortoise/voices/lj/1.wav b/tortoise/voices/lj/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..5d86776a4dd406fee2cfb07f87ddf09431f075a0 Binary files /dev/null and b/tortoise/voices/lj/1.wav differ diff --git a/tortoise/voices/lj/2.wav b/tortoise/voices/lj/2.wav new file mode 100644 index 
0000000000000000000000000000000000000000..75d66e4355401924a79260b7f990145288fd0c14 Binary files /dev/null and b/tortoise/voices/lj/2.wav differ diff --git a/tortoise/voices/mol/1.wav b/tortoise/voices/mol/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..b3244a5456f94172f199d4853fffcddc9351caf3 Binary files /dev/null and b/tortoise/voices/mol/1.wav differ diff --git a/tortoise/voices/mol/2.wav b/tortoise/voices/mol/2.wav new file mode 100644 index 0000000000000000000000000000000000000000..b6d3928e189780584870292d43f4e1d53503a2bd Binary files /dev/null and b/tortoise/voices/mol/2.wav differ diff --git a/tortoise/voices/myself/1.wav b/tortoise/voices/myself/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..83b7f804f43ca482e68c569bfafc92d5bd4121a0 Binary files /dev/null and b/tortoise/voices/myself/1.wav differ diff --git a/tortoise/voices/myself/2.wav b/tortoise/voices/myself/2.wav new file mode 100644 index 0000000000000000000000000000000000000000..faec7ddf078e7a340cb54aa399da21c0f9b21c0d Binary files /dev/null and b/tortoise/voices/myself/2.wav differ diff --git a/tortoise/voices/myself/3.wav b/tortoise/voices/myself/3.wav new file mode 100644 index 0000000000000000000000000000000000000000..374c799874a4d0940fe5df2bd590daf2d5244e78 Binary files /dev/null and b/tortoise/voices/myself/3.wav differ diff --git a/tortoise/voices/pat/1.wav b/tortoise/voices/pat/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..8c80c24e1bb9fe59713c9b3f7c320a8568aed7ae Binary files /dev/null and b/tortoise/voices/pat/1.wav differ diff --git a/tortoise/voices/pat/2.wav b/tortoise/voices/pat/2.wav new file mode 100644 index 0000000000000000000000000000000000000000..5503b1c0f2fb1b327eb7b8f7e43123702f916389 Binary files /dev/null and b/tortoise/voices/pat/2.wav differ diff --git a/tortoise/voices/pat/3.wav b/tortoise/voices/pat/3.wav new file mode 100644 index 
0000000000000000000000000000000000000000..ec4e85396b5eaa59d12967212cdf621346707a37 Binary files /dev/null and b/tortoise/voices/pat/3.wav differ diff --git a/tortoise/voices/pat/4.wav b/tortoise/voices/pat/4.wav new file mode 100644 index 0000000000000000000000000000000000000000..5949dd2d1200b8103aeaaada8262ee359cad4669 Binary files /dev/null and b/tortoise/voices/pat/4.wav differ diff --git a/tortoise/voices/pat2/00100.mp3 b/tortoise/voices/pat2/00100.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..fd50dc458de5ee97b79124fda28c87ff3a28cbe3 Binary files /dev/null and b/tortoise/voices/pat2/00100.mp3 differ diff --git a/tortoise/voices/pat2/00112.mp3 b/tortoise/voices/pat2/00112.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..4b27bef85112dc59da0f44d7f4ee5a65851bf4c9 Binary files /dev/null and b/tortoise/voices/pat2/00112.mp3 differ diff --git a/tortoise/voices/pat2/00130.mp3 b/tortoise/voices/pat2/00130.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..36b5e5487dffbeade03737e87fe207aef971981d Binary files /dev/null and b/tortoise/voices/pat2/00130.mp3 differ diff --git a/tortoise/voices/pat2/00159.mp3 b/tortoise/voices/pat2/00159.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..65b41e0dbc043527eafb11862d98022039f2df46 Binary files /dev/null and b/tortoise/voices/pat2/00159.mp3 differ diff --git a/tortoise/voices/rainbow/1.wav b/tortoise/voices/rainbow/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..f32e14d8d4e5e22129ffc7388d2e0c1bd05a959e Binary files /dev/null and b/tortoise/voices/rainbow/1.wav differ diff --git a/tortoise/voices/rainbow/2.wav b/tortoise/voices/rainbow/2.wav new file mode 100644 index 0000000000000000000000000000000000000000..1eae99c5be7e8ac186b8eb3b25086b521627e1e6 Binary files /dev/null and b/tortoise/voices/rainbow/2.wav differ diff --git a/tortoise/voices/rainbow/3.wav b/tortoise/voices/rainbow/3.wav new file mode 100644 
index 0000000000000000000000000000000000000000..71bc300839e52e3fb4eba1dca80addb264abdbda Binary files /dev/null and b/tortoise/voices/rainbow/3.wav differ diff --git a/tortoise/voices/rainbow/4.wav b/tortoise/voices/rainbow/4.wav new file mode 100644 index 0000000000000000000000000000000000000000..d878e218ae3d2f69dd1d05118515e13c4aa2c56a Binary files /dev/null and b/tortoise/voices/rainbow/4.wav differ diff --git a/tortoise/voices/rainbow/5.wav b/tortoise/voices/rainbow/5.wav new file mode 100644 index 0000000000000000000000000000000000000000..f6d9cc46c2949277262571249bad4a0e45a5746b Binary files /dev/null and b/tortoise/voices/rainbow/5.wav differ diff --git a/tortoise/voices/snakes/00115.mp3 b/tortoise/voices/snakes/00115.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..e9770ba39eb810f439b7f417acf8d10f2c60b931 Binary files /dev/null and b/tortoise/voices/snakes/00115.mp3 differ diff --git a/tortoise/voices/snakes/00162.mp3 b/tortoise/voices/snakes/00162.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..503aa7274d23ce55d64d5bbcc9afaf331370b8e0 Binary files /dev/null and b/tortoise/voices/snakes/00162.mp3 differ diff --git a/tortoise/voices/snakes/03504.mp3 b/tortoise/voices/snakes/03504.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..bd4f03946102f92f89a487a1421efaccfbea344d Binary files /dev/null and b/tortoise/voices/snakes/03504.mp3 differ diff --git a/tortoise/voices/tim_reynolds/1.mp3 b/tortoise/voices/tim_reynolds/1.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..445db30cf554422be9b76b20617d264cb715f173 Binary files /dev/null and b/tortoise/voices/tim_reynolds/1.mp3 differ diff --git a/tortoise/voices/tim_reynolds/2.mp3 b/tortoise/voices/tim_reynolds/2.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..6f09722dedffdd6829a13327d7536bb6938e3aea Binary files /dev/null and b/tortoise/voices/tim_reynolds/2.mp3 differ diff --git 
a/tortoise/voices/tim_reynolds/3.mp3 b/tortoise/voices/tim_reynolds/3.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..717a7ed0e6fd7488903043abade196c933ebad2e Binary files /dev/null and b/tortoise/voices/tim_reynolds/3.mp3 differ diff --git a/tortoise/voices/tim_reynolds/4.mp3 b/tortoise/voices/tim_reynolds/4.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..458d8121e810881589fd3f0d96ab27a4749e81f8 Binary files /dev/null and b/tortoise/voices/tim_reynolds/4.mp3 differ diff --git a/tortoise/voices/tom/1.wav b/tortoise/voices/tom/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..4e91bf9d20e2cbea2cdced3ad1133d8d7f2f6765 Binary files /dev/null and b/tortoise/voices/tom/1.wav differ diff --git a/tortoise/voices/tom/2.wav b/tortoise/voices/tom/2.wav new file mode 100644 index 0000000000000000000000000000000000000000..fb3d38d47cfcc4debd4d1cab7c4788ba8c39d0d7 Binary files /dev/null and b/tortoise/voices/tom/2.wav differ diff --git a/tortoise/voices/tom/3.wav b/tortoise/voices/tom/3.wav new file mode 100644 index 0000000000000000000000000000000000000000..07b0b14ceed1c0ff414a7473c9255bb31667ea40 Binary files /dev/null and b/tortoise/voices/tom/3.wav differ diff --git a/tortoise/voices/tom/4.wav b/tortoise/voices/tom/4.wav new file mode 100644 index 0000000000000000000000000000000000000000..0c64b0ebeb84cf4dd489b1073596f3df0b145538 Binary files /dev/null and b/tortoise/voices/tom/4.wav differ diff --git a/tortoise/voices/train_atkins/1.wav b/tortoise/voices/train_atkins/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..cf721d3f6c9998e000f1b30ce27f6c86c698cf05 Binary files /dev/null and b/tortoise/voices/train_atkins/1.wav differ diff --git a/tortoise/voices/train_atkins/2.wav b/tortoise/voices/train_atkins/2.wav new file mode 100644 index 0000000000000000000000000000000000000000..096b0b277908ff857c33e3de43d98339306e1651 Binary files /dev/null and 
b/tortoise/voices/train_atkins/2.wav differ diff --git a/tortoise/voices/train_daws/1.mp3 b/tortoise/voices/train_daws/1.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..4f2dbb0e8f5eda44685e74f2619d56f2b6fefbdb Binary files /dev/null and b/tortoise/voices/train_daws/1.mp3 differ diff --git a/tortoise/voices/train_daws/2.mp3 b/tortoise/voices/train_daws/2.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..f754f03885e627bbd4bc5c0d936271aad1aa4c18 Binary files /dev/null and b/tortoise/voices/train_daws/2.mp3 differ diff --git a/tortoise/voices/train_daws/3.mp3 b/tortoise/voices/train_daws/3.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..d9dace8f90114c72a5bb28c3083e75a5d2f0e604 Binary files /dev/null and b/tortoise/voices/train_daws/3.mp3 differ diff --git a/tortoise/voices/train_dotrice/1.wav b/tortoise/voices/train_dotrice/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..7babde7c9827dbdeb3bdbd30b0e7f62e98fe0f33 Binary files /dev/null and b/tortoise/voices/train_dotrice/1.wav differ diff --git a/tortoise/voices/train_dotrice/2.wav b/tortoise/voices/train_dotrice/2.wav new file mode 100644 index 0000000000000000000000000000000000000000..8f41a826e2de82263e22e49ef67ec6f05d79a74f Binary files /dev/null and b/tortoise/voices/train_dotrice/2.wav differ diff --git a/tortoise/voices/train_dreams/1.mp3 b/tortoise/voices/train_dreams/1.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..f820e287f0af56d9eef3ae0125602c10d41cdd1b Binary files /dev/null and b/tortoise/voices/train_dreams/1.mp3 differ diff --git a/tortoise/voices/train_dreams/2.mp3 b/tortoise/voices/train_dreams/2.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..fbdd0ff8aef7d283bcef382b1639ae79aafdfcb8 Binary files /dev/null and b/tortoise/voices/train_dreams/2.mp3 differ diff --git a/tortoise/voices/train_dreams/3.mp3 b/tortoise/voices/train_dreams/3.mp3 new file mode 100644 
index 0000000000000000000000000000000000000000..2b73e0690a5929453980613c366b9300d7a5effc Binary files /dev/null and b/tortoise/voices/train_dreams/3.mp3 differ diff --git a/tortoise/voices/train_empire/1.mp3 b/tortoise/voices/train_empire/1.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..de570b857522709f2016fc3972f466232068fcd8 Binary files /dev/null and b/tortoise/voices/train_empire/1.mp3 differ diff --git a/tortoise/voices/train_empire/2.mp3 b/tortoise/voices/train_empire/2.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..0a59abd2fc4f459c783ed2adf8582f4d69d531c2 Binary files /dev/null and b/tortoise/voices/train_empire/2.mp3 differ diff --git a/tortoise/voices/train_empire/3.mp3 b/tortoise/voices/train_empire/3.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..674ad228caa0f16b1a909bc8e321b8c1dd9a8af0 Binary files /dev/null and b/tortoise/voices/train_empire/3.mp3 differ diff --git a/tortoise/voices/train_grace/1.wav b/tortoise/voices/train_grace/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..b2a243cea416cad34ea374b3964ecb0f8774211e Binary files /dev/null and b/tortoise/voices/train_grace/1.wav differ diff --git a/tortoise/voices/train_grace/2.wav b/tortoise/voices/train_grace/2.wav new file mode 100644 index 0000000000000000000000000000000000000000..41ca66e8b6ef9d65235c5a6bf42bdff8afeee42d Binary files /dev/null and b/tortoise/voices/train_grace/2.wav differ diff --git a/tortoise/voices/train_kennard/1.wav b/tortoise/voices/train_kennard/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..d98ca272441703d70a195f2c098a78a4ff6f100e Binary files /dev/null and b/tortoise/voices/train_kennard/1.wav differ diff --git a/tortoise/voices/train_kennard/2.wav b/tortoise/voices/train_kennard/2.wav new file mode 100644 index 0000000000000000000000000000000000000000..9548fb907cadf579d96393b8e2d99c8d378eeaa4 Binary files /dev/null and 
b/tortoise/voices/train_kennard/2.wav differ diff --git a/tortoise/voices/train_lescault/lescault_new1.wav b/tortoise/voices/train_lescault/lescault_new1.wav new file mode 100644 index 0000000000000000000000000000000000000000..56673ae8933b5c6345e6780566a61dfc07d1ba9b Binary files /dev/null and b/tortoise/voices/train_lescault/lescault_new1.wav differ diff --git a/tortoise/voices/train_lescault/lescault_new2.wav b/tortoise/voices/train_lescault/lescault_new2.wav new file mode 100644 index 0000000000000000000000000000000000000000..5ef7635f78f021cfa54a90b7804d432cb99a8aa2 Binary files /dev/null and b/tortoise/voices/train_lescault/lescault_new2.wav differ diff --git a/tortoise/voices/train_lescault/lescault_new3.wav b/tortoise/voices/train_lescault/lescault_new3.wav new file mode 100644 index 0000000000000000000000000000000000000000..85f416eca6316bfa6553ea329d126fcaca57b2c1 Binary files /dev/null and b/tortoise/voices/train_lescault/lescault_new3.wav differ diff --git a/tortoise/voices/train_lescault/lescault_new4.wav b/tortoise/voices/train_lescault/lescault_new4.wav new file mode 100644 index 0000000000000000000000000000000000000000..0e290cf2422199918c93f40d46d07b9c11f35716 --- /dev/null +++ b/tortoise/voices/train_lescault/lescault_new4.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b36a0978103fcbdaa89486de63f5eabfcc453cf94c6c6c85a0917b5ace3c4f5f +size 1018880 diff --git a/tortoise/voices/train_lescault/lescault_new5.wav b/tortoise/voices/train_lescault/lescault_new5.wav new file mode 100644 index 0000000000000000000000000000000000000000..17496bf2b0b67ebbaff2877cb948cec43932d17b Binary files /dev/null and b/tortoise/voices/train_lescault/lescault_new5.wav differ diff --git a/tortoise/voices/train_mouse/1.mp3 b/tortoise/voices/train_mouse/1.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..937f1820108673b4c7eeeab6585c196839735d18 Binary files /dev/null and b/tortoise/voices/train_mouse/1.mp3 differ diff --git 
a/tortoise/voices/train_mouse/2.mp3 b/tortoise/voices/train_mouse/2.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..275d90f252728cef5bb402375b3f50aabb8ec617 Binary files /dev/null and b/tortoise/voices/train_mouse/2.mp3 differ diff --git a/tortoise/voices/weaver/1.wav b/tortoise/voices/weaver/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..7283087f36455bf350198051b1ecbfce5132df39 Binary files /dev/null and b/tortoise/voices/weaver/1.wav differ diff --git a/tortoise/voices/weaver/2.wav b/tortoise/voices/weaver/2.wav new file mode 100644 index 0000000000000000000000000000000000000000..de7206e1e8c691995a357d25854547f4bd26d367 Binary files /dev/null and b/tortoise/voices/weaver/2.wav differ diff --git a/tortoise/voices/weaver/3.wav b/tortoise/voices/weaver/3.wav new file mode 100644 index 0000000000000000000000000000000000000000..6b4b4feb435dda18d8c8b8d27622c32c70819d3a Binary files /dev/null and b/tortoise/voices/weaver/3.wav differ diff --git a/tortoise/voices/william/1.wav b/tortoise/voices/william/1.wav new file mode 100644 index 0000000000000000000000000000000000000000..15ef32bc8f2f26ccfba93bb8ae69a408192171ac Binary files /dev/null and b/tortoise/voices/william/1.wav differ diff --git a/tortoise/voices/william/2.wav b/tortoise/voices/william/2.wav new file mode 100644 index 0000000000000000000000000000000000000000..f72eb62174838d2fd690b90122646da6d284603b Binary files /dev/null and b/tortoise/voices/william/2.wav differ diff --git a/tortoise/voices/william/3.wav b/tortoise/voices/william/3.wav new file mode 100644 index 0000000000000000000000000000000000000000..d9b4002061110042e0713d8386e97334226d3de4 Binary files /dev/null and b/tortoise/voices/william/3.wav differ diff --git a/tortoise/voices/william/4.wav b/tortoise/voices/william/4.wav new file mode 100644 index 0000000000000000000000000000000000000000..e03c1812224bbbd9cd54e260e25f953e92448a44 Binary files /dev/null and b/tortoise/voices/william/4.wav differ 
diff --git a/wav2lip/__init__.py b/wav2lip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/wav2lip/__pycache__/__init__.cpython-310.pyc b/wav2lip/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6aa0fbc9f35574271508bdfee7c2ed45cf480261 Binary files /dev/null and b/wav2lip/__pycache__/__init__.cpython-310.pyc differ diff --git a/wav2lip/__pycache__/audio.cpython-310.pyc b/wav2lip/__pycache__/audio.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33f36072229c796aa39b0223de7e55f732967eb6 Binary files /dev/null and b/wav2lip/__pycache__/audio.cpython-310.pyc differ diff --git a/wav2lip/__pycache__/hparams.cpython-310.pyc b/wav2lip/__pycache__/hparams.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60fae5a60b0243310c281552e99f65805844e7ec Binary files /dev/null and b/wav2lip/__pycache__/hparams.cpython-310.pyc differ diff --git a/wav2lip/audio.py b/wav2lip/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..ca7a35b445e9742fcfd93d326994e8b2be978ab2 --- /dev/null +++ b/wav2lip/audio.py @@ -0,0 +1,141 @@ +import librosa +import librosa.filters +import numpy as np +# import tensorflow as tf +from scipy import signal +from scipy.io import wavfile +from wav2lip.hparams import hparams as hp + +def load_wav(path, sr): + return librosa.core.load(path, sr=sr)[0] + +def save_wav(wav, path, sr): + wav *= 32767 / max(0.01, np.max(np.abs(wav))) + #proposed by @dsmiller + wavfile.write(path, sr, wav.astype(np.int16)) + +def save_wavenet_wav(wav, path, sr): + librosa.output.write_wav(path, wav, sr=sr) + +def preemphasis(wav, k, preemphasize=True): + if preemphasize: + return signal.lfilter([1, -k], [1], wav) + return wav + +def inv_preemphasis(wav, k, inv_preemphasize=True): + if inv_preemphasize: + return signal.lfilter([1], [1, -k], wav) + 
return wav + +def get_hop_size(): + hop_size = hp.hop_size + if hop_size is None: + assert hp.frame_shift_ms is not None + hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate) + return hop_size + +def linearspectrogram(wav): + D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) + S = _amp_to_db(np.abs(D)) - hp.ref_level_db + + if hp.signal_normalization: + return _normalize(S) + return S + +def melspectrogram(wav): + D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) + S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db + + if hp.signal_normalization: + return _normalize(S) + return S + +def _lws_processor(): + import lws + return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech") + +def _stft(y): + if hp.use_lws: + return _lws_processor(hp).stft(y).T + else: + return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size) + +########################################################## +#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) 
+def num_frames(length, fsize, fshift): + """Compute number of time frames of spectrogram + """ + pad = (fsize - fshift) + if length % fshift == 0: + M = (length + pad * 2 - fsize) // fshift + 1 + else: + M = (length + pad * 2 - fsize) // fshift + 2 + return M + + +def pad_lr(x, fsize, fshift): + """Compute left and right padding + """ + M = num_frames(len(x), fsize, fshift) + pad = (fsize - fshift) + T = len(x) + 2 * pad + r = (M - 1) * fshift + fsize - T + return pad, pad + r +########################################################## +#Librosa correct padding +def librosa_pad_lr(x, fsize, fshift): + return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] + +# Conversions +_mel_basis = None + +def _linear_to_mel(spectogram): + global _mel_basis + if _mel_basis is None: + _mel_basis = _build_mel_basis() + return np.dot(_mel_basis, spectogram) + +# def _build_mel_basis(): +# assert hp.fmax <= hp.sample_rate // 2 +# return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels, +# fmin=hp.fmin, fmax=hp.fmax) + +def _build_mel_basis(): + assert hp.fmax <= hp.sample_rate // 2 + return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels, + fmin=hp.fmin, fmax=hp.fmax) + +def _amp_to_db(x): + min_level = np.exp(hp.min_level_db / 20 * np.log(10)) + return 20 * np.log10(np.maximum(min_level, x)) + +def _db_to_amp(x): + return np.power(10.0, (x) * 0.05) + +def _normalize(S): + if hp.allow_clipping_in_normalization: + if hp.symmetric_mels: + return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value, + -hp.max_abs_value, hp.max_abs_value) + else: + return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value) + + assert S.max() <= 0 and S.min() - hp.min_level_db >= 0 + if hp.symmetric_mels: + return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value + else: + return hp.max_abs_value * ((S - hp.min_level_db) / 
(-hp.min_level_db)) + +def _denormalize(D): + if hp.allow_clipping_in_normalization: + if hp.symmetric_mels: + return (((np.clip(D, -hp.max_abs_value, + hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + + hp.min_level_db) + else: + return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) + + if hp.symmetric_mels: + return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) + else: + return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) diff --git a/wav2lip/face_detection/README.md b/wav2lip/face_detection/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c073376e4eeda6d4b29cc31c50cb7e88ab42bb73 --- /dev/null +++ b/wav2lip/face_detection/README.md @@ -0,0 +1 @@ +The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time. 
\ No newline at end of file diff --git a/wav2lip/face_detection/__init__.py b/wav2lip/face_detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4bae29fd5f85b41e4669302bd2603bc6924eddc7 --- /dev/null +++ b/wav2lip/face_detection/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +__author__ = """Adrian Bulat""" +__email__ = 'adrian.bulat@nottingham.ac.uk' +__version__ = '1.0.1' + +from .api import FaceAlignment, LandmarksType, NetworkSize diff --git a/wav2lip/face_detection/__pycache__/__init__.cpython-310.pyc b/wav2lip/face_detection/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa3abc491a75e78eed1840830eda6ee095a01e84 Binary files /dev/null and b/wav2lip/face_detection/__pycache__/__init__.cpython-310.pyc differ diff --git a/wav2lip/face_detection/__pycache__/api.cpython-310.pyc b/wav2lip/face_detection/__pycache__/api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ee303c3c69cde5b92c589fe9e8e7339a61393d1 Binary files /dev/null and b/wav2lip/face_detection/__pycache__/api.cpython-310.pyc differ diff --git a/wav2lip/face_detection/__pycache__/models.cpython-310.pyc b/wav2lip/face_detection/__pycache__/models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cf6eb23ba82157d1718b30ddadbc175c04d5f37 Binary files /dev/null and b/wav2lip/face_detection/__pycache__/models.cpython-310.pyc differ diff --git a/wav2lip/face_detection/__pycache__/utils.cpython-310.pyc b/wav2lip/face_detection/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b4102292623db37507cef19f62709c9e7d39db1 Binary files /dev/null and b/wav2lip/face_detection/__pycache__/utils.cpython-310.pyc differ diff --git a/wav2lip/face_detection/api.py b/wav2lip/face_detection/api.py new file mode 100644 index 
0000000000000000000000000000000000000000..2cbda02c4b268fdac446614c240d9d2d56092708 --- /dev/null +++ b/wav2lip/face_detection/api.py @@ -0,0 +1,79 @@ +from __future__ import print_function +import os +import torch +from torch.utils.model_zoo import load_url +from enum import Enum +import numpy as np +import cv2 +try: + import urllib.request as request_file +except BaseException: + import urllib as request_file + +from .models import FAN, ResNetDepth +from .utils import * + + +class LandmarksType(Enum): + """Enum class defining the type of landmarks to detect. + + ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face + ``_2halfD`` - this points represent the projection of the 3D points into 3D + ``_3D`` - detect the points ``(x,y,z)``` in a 3D space + + """ + _2D = 1 + _2halfD = 2 + _3D = 3 + + +class NetworkSize(Enum): + # TINY = 1 + # SMALL = 2 + # MEDIUM = 3 + LARGE = 4 + + def __new__(cls, value): + member = object.__new__(cls) + member._value_ = value + return member + + def __int__(self): + return self.value + +ROOT = os.path.dirname(os.path.abspath(__file__)) + +class FaceAlignment: + def __init__(self, landmarks_type, network_size=NetworkSize.LARGE, + device='cuda', flip_input=False, face_detector='sfd', verbose=False): + self.device = device + self.flip_input = flip_input + self.landmarks_type = landmarks_type + self.verbose = verbose + + network_size = int(network_size) + + if 'cuda' in device: + torch.backends.cudnn.benchmark = True + + # Get the face detector + face_detector_module = __import__('wav2lip.face_detection.detection.' 
+ face_detector, + globals(), locals(), [face_detector], 0) + self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose) + + def get_detections_for_batch(self, images): + images = images[..., ::-1] + detected_faces = self.face_detector.detect_from_batch(images.copy()) + results = [] + + for i, d in enumerate(detected_faces): + if len(d) == 0: + results.append(None) + continue + d = d[0] + d = np.clip(d, 0, None) + + x1, y1, x2, y2 = map(int, d[:-1]) + results.append((x1, y1, x2, y2)) + + return results \ No newline at end of file diff --git a/wav2lip/face_detection/detection/__init__.py b/wav2lip/face_detection/detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a6b0402dae864a3cc5dc2a90a412fd842a0efc7 --- /dev/null +++ b/wav2lip/face_detection/detection/__init__.py @@ -0,0 +1 @@ +from .core import FaceDetector \ No newline at end of file diff --git a/wav2lip/face_detection/detection/__pycache__/__init__.cpython-310.pyc b/wav2lip/face_detection/detection/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ec14ba44f7c9bf8e0d965dcf509a72c30d75ae1 Binary files /dev/null and b/wav2lip/face_detection/detection/__pycache__/__init__.cpython-310.pyc differ diff --git a/wav2lip/face_detection/detection/__pycache__/core.cpython-310.pyc b/wav2lip/face_detection/detection/__pycache__/core.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6a20f2e209ecc5929f3f24cdae860c7e6d46426 Binary files /dev/null and b/wav2lip/face_detection/detection/__pycache__/core.cpython-310.pyc differ diff --git a/wav2lip/face_detection/detection/core.py b/wav2lip/face_detection/detection/core.py new file mode 100644 index 0000000000000000000000000000000000000000..0f8275e8e53143f66298f75f0517c234a68778cd --- /dev/null +++ b/wav2lip/face_detection/detection/core.py @@ -0,0 +1,130 @@ +import logging +import glob +from tqdm import tqdm +import numpy 
as np +import torch +import cv2 + + +class FaceDetector(object): + """An abstract class representing a face detector. + + Any other face detection implementation must subclass it. All subclasses + must implement ``detect_from_image``, that return a list of detected + bounding boxes. Optionally, for speed considerations detect from path is + recommended. + """ + + def __init__(self, device, verbose): + self.device = device + self.verbose = verbose + + if verbose: + if 'cpu' in device: + logger = logging.getLogger(__name__) + logger.warning("Detection running on CPU, this may be potentially slow.") + + if 'cpu' not in device and 'cuda' not in device: + if verbose: + logger.error("Expected values for device are: {cpu, cuda} but got: %s", device) + raise ValueError + + def detect_from_image(self, tensor_or_path): + """Detects faces in a given image. + + This function detects the faces present in a provided BGR(usually) + image. The input can be either the image itself or the path to it. + + Arguments: + tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path + to an image or the image itself. + + Example:: + + >>> path_to_image = 'data/image_01.jpg' + ... detected_faces = detect_from_image(path_to_image) + [A list of bounding boxes (x1, y1, x2, y2)] + >>> image = cv2.imread(path_to_image) + ... detected_faces = detect_from_image(image) + [A list of bounding boxes (x1, y1, x2, y2)] + + """ + raise NotImplementedError + + def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True): + """Detects faces from all the images present in a given directory. 
+ + Arguments: + path {string} -- a string containing a path that points to the folder containing the images + + Keyword Arguments: + extensions {list} -- list of string containing the extensions to be + consider in the following format: ``.extension_name`` (default: + {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the + folder recursively (default: {False}) show_progress_bar {bool} -- + display a progressbar (default: {True}) + + Example: + >>> directory = 'data' + ... detected_faces = detect_from_directory(directory) + {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]} + + """ + if self.verbose: + logger = logging.getLogger(__name__) + + if len(extensions) == 0: + if self.verbose: + logger.error("Expected at list one extension, but none was received.") + raise ValueError + + if self.verbose: + logger.info("Constructing the list of images.") + additional_pattern = '/**/*' if recursive else '/*' + files = [] + for extension in extensions: + files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive)) + + if self.verbose: + logger.info("Finished searching for images. 
%s images found", len(files)) + logger.info("Preparing to run the detection.") + + predictions = {} + for image_path in tqdm(files, disable=not show_progress_bar): + if self.verbose: + logger.info("Running the face detector on image: %s", image_path) + predictions[image_path] = self.detect_from_image(image_path) + + if self.verbose: + logger.info("The detector was successfully run on all %s images", len(files)) + + return predictions + + @property + def reference_scale(self): + raise NotImplementedError + + @property + def reference_x_shift(self): + raise NotImplementedError + + @property + def reference_y_shift(self): + raise NotImplementedError + + @staticmethod + def tensor_or_path_to_ndarray(tensor_or_path, rgb=True): + """Convert path (represented as a string) or torch.tensor to a numpy.ndarray + + Arguments: + tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself + """ + if isinstance(tensor_or_path, str): + return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1] + elif torch.is_tensor(tensor_or_path): + # Call cpu in case its coming from cuda + return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy() + elif isinstance(tensor_or_path, np.ndarray): + return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path + else: + raise TypeError diff --git a/wav2lip/face_detection/detection/sfd/__init__.py b/wav2lip/face_detection/detection/sfd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a63ecd45658f22e66c171ada751fb33764d4559 --- /dev/null +++ b/wav2lip/face_detection/detection/sfd/__init__.py @@ -0,0 +1 @@ +from .sfd_detector import SFDDetector as FaceDetector \ No newline at end of file diff --git a/wav2lip/face_detection/detection/sfd/__pycache__/__init__.cpython-310.pyc b/wav2lip/face_detection/detection/sfd/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..25d53a8f044dfa0a61d84b66e21c81610e3304f6 Binary files /dev/null and b/wav2lip/face_detection/detection/sfd/__pycache__/__init__.cpython-310.pyc differ diff --git a/wav2lip/face_detection/detection/sfd/__pycache__/bbox.cpython-310.pyc b/wav2lip/face_detection/detection/sfd/__pycache__/bbox.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75e46899a33ec992e062b22a1dbbdde7c43dee45 Binary files /dev/null and b/wav2lip/face_detection/detection/sfd/__pycache__/bbox.cpython-310.pyc differ diff --git a/wav2lip/face_detection/detection/sfd/__pycache__/detect.cpython-310.pyc b/wav2lip/face_detection/detection/sfd/__pycache__/detect.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8fc25674e7488b31ab0602dd521260f48df4088 Binary files /dev/null and b/wav2lip/face_detection/detection/sfd/__pycache__/detect.cpython-310.pyc differ diff --git a/wav2lip/face_detection/detection/sfd/__pycache__/net_s3fd.cpython-310.pyc b/wav2lip/face_detection/detection/sfd/__pycache__/net_s3fd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17114ad184261f1117c8d05fad7fbf93168b1651 Binary files /dev/null and b/wav2lip/face_detection/detection/sfd/__pycache__/net_s3fd.cpython-310.pyc differ diff --git a/wav2lip/face_detection/detection/sfd/__pycache__/sfd_detector.cpython-310.pyc b/wav2lip/face_detection/detection/sfd/__pycache__/sfd_detector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eff96e4a1b9f81f2744efd85ca563d81a1fed411 Binary files /dev/null and b/wav2lip/face_detection/detection/sfd/__pycache__/sfd_detector.cpython-310.pyc differ diff --git a/wav2lip/face_detection/detection/sfd/bbox.py b/wav2lip/face_detection/detection/sfd/bbox.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd7222e5e5f78a51944cbeed3cccbacddc46bed --- /dev/null +++ b/wav2lip/face_detection/detection/sfd/bbox.py 
@@ -0,0 +1,129 @@ +from __future__ import print_function +import os +import sys +import cv2 +import random +import datetime +import time +import math +import argparse +import numpy as np +import torch + +try: + from iou import IOU +except BaseException: + # IOU cython speedup 10x + def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2): + sa = abs((ax2 - ax1) * (ay2 - ay1)) + sb = abs((bx2 - bx1) * (by2 - by1)) + x1, y1 = max(ax1, bx1), max(ay1, by1) + x2, y2 = min(ax2, bx2), min(ay2, by2) + w = x2 - x1 + h = y2 - y1 + if w < 0 or h < 0: + return 0.0 + else: + return 1.0 * w * h / (sa + sb - w * h) + + +def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh): + xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1 + dx, dy = (xc - axc) / aww, (yc - ayc) / ahh + dw, dh = math.log(ww / aww), math.log(hh / ahh) + return dx, dy, dw, dh + + +def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh): + xc, yc = dx * aww + axc, dy * ahh + ayc + ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh + x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2 + return x1, y1, x2, y2 + + +def nms(dets, thresh): + if 0 == len(dets): + return [] + x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]]) + xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]]) + + w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1) + ovr = w * h / (areas[i] + areas[order[1:]] - w * h) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. 
+ Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded boxes (tensor), Shape: [num_priors, 4] + """ + + # dist b/t match center and prior's center + g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, 2:]) + # match wh / prior wh + g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] + g_wh = torch.log(g_wh) / variances[1] + # return target for smooth_l1_loss + return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] + + +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat(( + priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + +def batch_decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. 
+ variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat(( + priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2) + boxes[:, :, :2] -= boxes[:, :, 2:] / 2 + boxes[:, :, 2:] += boxes[:, :, :2] + return boxes diff --git a/wav2lip/face_detection/detection/sfd/detect.py b/wav2lip/face_detection/detection/sfd/detect.py new file mode 100644 index 0000000000000000000000000000000000000000..efef6273adf317bc17f3dd0f02423c0701ca218e --- /dev/null +++ b/wav2lip/face_detection/detection/sfd/detect.py @@ -0,0 +1,112 @@ +import torch +import torch.nn.functional as F + +import os +import sys +import cv2 +import random +import datetime +import math +import argparse +import numpy as np + +import scipy.io as sio +import zipfile +from .net_s3fd import s3fd +from .bbox import * + + +def detect(net, img, device): + img = img - np.array([104, 117, 123]) + img = img.transpose(2, 0, 1) + img = img.reshape((1,) + img.shape) + + if 'cuda' in device: + torch.backends.cudnn.benchmark = True + + img = torch.from_numpy(img).float().to(device) + BB, CC, HH, WW = img.size() + with torch.no_grad(): + olist = net(img) + + bboxlist = [] + for i in range(len(olist) // 2): + olist[i * 2] = F.softmax(olist[i * 2], dim=1) + olist = [oelem.data.cpu() for oelem in olist] + for i in range(len(olist) // 2): + ocls, oreg = olist[i * 2], olist[i * 2 + 1] + FB, FC, FH, FW = ocls.size() # feature map size + stride = 2**(i + 2) # 4,8,16,32,64,128 + anchor = stride * 4 + poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) + for Iindex, hindex, windex in poss: + axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride + score = ocls[0, 1, hindex, windex] + loc = oreg[0, :, hindex, windex].contiguous().view(1, 4) + priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]) + variances = [0.1, 0.2] + box = decode(loc, priors, variances) + 
x1, y1, x2, y2 = box[0] * 1.0 + # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1) + bboxlist.append([x1, y1, x2, y2, score]) + bboxlist = np.array(bboxlist) + if 0 == len(bboxlist): + bboxlist = np.zeros((1, 5)) + + return bboxlist + +def batch_detect(net, imgs, device): + imgs = imgs - np.array([104, 117, 123]) + imgs = imgs.transpose(0, 3, 1, 2) + + if 'cuda' in device: + torch.backends.cudnn.benchmark = True + + imgs = torch.from_numpy(imgs).float().to(device) + BB, CC, HH, WW = imgs.size() + with torch.no_grad(): + olist = net(imgs) + + bboxlist = [] + for i in range(len(olist) // 2): + olist[i * 2] = F.softmax(olist[i * 2], dim=1) + olist = [oelem.data.cpu() for oelem in olist] + for i in range(len(olist) // 2): + ocls, oreg = olist[i * 2], olist[i * 2 + 1] + FB, FC, FH, FW = ocls.size() # feature map size + stride = 2**(i + 2) # 4,8,16,32,64,128 + anchor = stride * 4 + poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) + for Iindex, hindex, windex in poss: + axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride + score = ocls[:, 1, hindex, windex] + loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4) + priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4) + variances = [0.1, 0.2] + box = batch_decode(loc, priors, variances) + box = box[:, 0] * 1.0 + # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1) + bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy()) + bboxlist = np.array(bboxlist) + if 0 == len(bboxlist): + bboxlist = np.zeros((1, BB, 5)) + + return bboxlist + +def flip_detect(net, img, device): + img = cv2.flip(img, 1) + b = detect(net, img, device) + + bboxlist = np.zeros(b.shape) + bboxlist[:, 0] = img.shape[1] - b[:, 2] + bboxlist[:, 1] = b[:, 1] + bboxlist[:, 2] = img.shape[1] - b[:, 0] + bboxlist[:, 3] = b[:, 3] + bboxlist[:, 4] = b[:, 4] + return bboxlist + + +def pts_to_bb(pts): + min_x, min_y = np.min(pts, axis=0) + 
max_x, max_y = np.max(pts, axis=0) + return np.array([min_x, min_y, max_x, max_y]) diff --git a/wav2lip/face_detection/detection/sfd/net_s3fd.py b/wav2lip/face_detection/detection/sfd/net_s3fd.py new file mode 100644 index 0000000000000000000000000000000000000000..fc64313c277ab594d0257585c70f147606693452 --- /dev/null +++ b/wav2lip/face_detection/detection/sfd/net_s3fd.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class L2Norm(nn.Module): + def __init__(self, n_channels, scale=1.0): + super(L2Norm, self).__init__() + self.n_channels = n_channels + self.scale = scale + self.eps = 1e-10 + self.weight = nn.Parameter(torch.Tensor(self.n_channels)) + self.weight.data *= 0.0 + self.weight.data += self.scale + + def forward(self, x): + norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps + x = x / norm * self.weight.view(1, -1, 1, 1) + return x + + +class s3fd(nn.Module): + def __init__(self): + super(s3fd, self).__init__() + self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + + self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + + self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + + self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, 
padding=3) + self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0) + + self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) + self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1) + + self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0) + self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) + + self.conv3_3_norm = L2Norm(256, scale=10) + self.conv4_3_norm = L2Norm(512, scale=8) + self.conv5_3_norm = L2Norm(512, scale=5) + + self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + + self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1) + self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1) + self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) + self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + h = F.relu(self.conv1_1(x)) + h = F.relu(self.conv1_2(h)) + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv2_1(h)) + h = F.relu(self.conv2_2(h)) + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv3_1(h)) + h = F.relu(self.conv3_2(h)) + h = F.relu(self.conv3_3(h)) + f3_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv4_1(h)) + h = F.relu(self.conv4_2(h)) + h = F.relu(self.conv4_3(h)) + f4_3 = h + h = F.max_pool2d(h, 2, 2) + + h = 
class SFDDetector(FaceDetector):
    """Face detector backed by the ``s3fd`` network.

    Arguments:
        device -- device the network is moved to and that is forwarded to
            ``detect`` / ``batch_detect``

    Keyword Arguments:
        path_to_detector {str} -- checkpoint file; when it does not exist the
            weights are downloaded from ``models_urls['s3fd']``
            (default: ``s3fd.pth`` next to this module)
        verbose {bool} -- forwarded to ``FaceDetector`` (default: {False})
    """

    def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
        # ``torch`` is not imported explicitly at the top of this module (it
        # could only arrive through the star imports), so bind it locally.
        import torch

        super(SFDDetector, self).__init__(device, verbose)

        # Initialise the face detector.  Always deserialise onto the CPU so a
        # checkpoint saved from a CUDA process can be loaded on a CPU-only
        # host; the module is moved to ``device`` afterwards.
        if not os.path.isfile(path_to_detector):
            model_weights = load_url(models_urls['s3fd'], map_location='cpu')
        else:
            model_weights = torch.load(path_to_detector, map_location='cpu')

        self.face_detector = s3fd()
        self.face_detector.load_state_dict(model_weights)
        self.face_detector.to(device)
        self.face_detector.eval()

    def detect_from_image(self, tensor_or_path):
        """Detect faces in one image.

        The input is converted with ``tensor_or_path_to_ndarray``, run through
        ``detect``, filtered by ``nms`` with threshold 0.3, and only rows whose
        last element (presumably the confidence score -- confirm in
        ``detect``) exceeds 0.5 are returned, as a list.
        """
        image = self.tensor_or_path_to_ndarray(tensor_or_path)

        bboxlist = detect(self.face_detector, image, device=self.device)
        keep = nms(bboxlist, 0.3)
        bboxlist = bboxlist[keep, :]
        bboxlist = [x for x in bboxlist if x[-1] > 0.5]

        return bboxlist

    def detect_from_batch(self, images):
        """Detect faces in a batch of images.

        ``batch_detect`` output is indexed per image along axis 1; each image
        gets the same NMS (0.3) and last-element (> 0.5) filtering as
        ``detect_from_image``.  Returns one list of boxes per image.
        """
        bboxlists = batch_detect(self.face_detector, images, device=self.device)
        keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
        bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
        bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]

        return bboxlists

    @property
    def reference_scale(self):
        """Constant reference scale reported by this detector (195)."""
        return 195

    @property
    def reference_x_shift(self):
        """Constant horizontal reference shift (0)."""
        return 0

    @property
    def reference_y_shift(self):
        """Constant vertical reference shift (0)."""
        return 0
class Bottleneck(nn.Module):
    """Residual block made of a 1x1, a 3x3 and a 1x1 convolution.

    The last convolution widens the features to ``planes * 4`` channels.  The
    skip path is the raw input, or ``downsample(input)`` when a ``downsample``
    module is supplied.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        widened = planes * 4

        # 1x1 reduction
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 convolution; the only layer that applies ``stride``
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 expansion
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv/BN stages, add the skip path, apply ReLU."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)
class FAN(nn.Module):
    """Stack of ``num_modules`` hourglass stages producing 68-channel heatmaps.

    A convolutional stem feeds the stages; every stage emits one heatmap
    tensor, and all stages except the last feed a re-projection of their
    features and heatmaps back into the next stage's input.
    """

    def __init__(self, num_modules=1):
        super(FAN, self).__init__()
        self.num_modules = num_modules

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = ConvBlock(64, 128)
        self.conv3 = ConvBlock(128, 128)
        self.conv4 = ConvBlock(128, 256)

        # Per-stage sub-modules.  Names and registration order are kept as-is
        # because they define the state_dict keys.
        last_stage = self.num_modules - 1
        for stage in range(self.num_modules):
            tag = str(stage)
            self.add_module('m' + tag, HourGlass(1, 4, 256))
            self.add_module('top_m_' + tag, ConvBlock(256, 256))
            self.add_module(
                'conv_last' + tag,
                nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
            self.add_module('bn_end' + tag, nn.BatchNorm2d(256))
            self.add_module(
                'l' + tag,
                nn.Conv2d(256, 68, kernel_size=1, stride=1, padding=0))

            if stage < last_stage:
                # Projections that merge this stage's output into the next input
                self.add_module(
                    'bl' + tag,
                    nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
                self.add_module(
                    'al' + tag,
                    nn.Conv2d(68, 256, kernel_size=1, stride=1, padding=0))

    def forward(self, x):
        """Return a list with one heatmap tensor per stage."""
        x = F.relu(self.bn1(self.conv1(x)), True)
        x = F.avg_pool2d(self.conv2(x), 2, stride=2)
        x = self.conv4(self.conv3(x))

        previous = x
        outputs = []
        last_stage = self.num_modules - 1
        for stage in range(self.num_modules):
            tag = str(stage)
            layers = self._modules

            features = layers['m' + tag](previous)
            features = layers['top_m_' + tag](features)
            features = layers['conv_last' + tag](features)
            features = F.relu(layers['bn_end' + tag](features), True)

            # Predict heatmaps
            heatmaps = layers['l' + tag](features)
            outputs.append(heatmaps)

            if stage < last_stage:
                features = layers['bl' + tag](features)
                remapped = layers['al' + tag](heatmaps)
                previous = previous + features + remapped

        return outputs
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x diff --git a/wav2lip/face_detection/utils.py b/wav2lip/face_detection/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc4cf3e328efaa227cbcfdd969e1056688adad5 --- /dev/null +++ b/wav2lip/face_detection/utils.py @@ -0,0 +1,313 @@ +from __future__ import print_function +import os +import sys +import time +import torch +import math +import numpy as np +import cv2 + + +def _gaussian( + size=3, sigma=0.25, amplitude=1, normalize=False, width=None, + height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5, + mean_vert=0.5): + # handle some defaults + if width is None: + width = size + if height is None: + height = size + if sigma_horz is None: + sigma_horz = sigma + if sigma_vert is None: + sigma_vert = sigma + center_x = mean_horz * width + 0.5 + center_y = mean_vert * height + 0.5 + gauss = np.empty((height, width), dtype=np.float32) + # generate kernel + for i in range(height): + for j in range(width): + gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / ( + sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / 
(sigma_vert * height), 2) / 2.0)) + if normalize: + gauss = gauss / np.sum(gauss) + return gauss + + +def draw_gaussian(image, point, sigma): + # Check if the gaussian is inside + ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)] + br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)] + if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1): + return image + size = 6 * sigma + 1 + g = _gaussian(size) + g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))] + g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))] + img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))] + img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))] + assert (g_x[0] > 0 and g_y[1] > 0) + image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1] + ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]] + image[image > 1] = 1 + return image + + +def transform(point, center, scale, resolution, invert=False): + """Generate and affine transformation matrix. + + Given a set of points, a center, a scale and a targer resolution, the + function generates and affine transformation matrix. If invert is ``True`` + it will produce the inverse transformation. 
def crop(image, center, scale, resolution=256.0):
    """Center crops an image or set of heatmaps

    The crop window is obtained by mapping the corners ``(1, 1)`` and
    ``(resolution, resolution)`` through the inverse ``transform``.  The part
    of the window that lies outside the image is left as zeros.  The result is
    resized to ``resolution`` x ``resolution`` with bilinear interpolation.

    Arguments:
        image {numpy.array} -- an rgb image (H x W x C) or a 2-D array (H x W)
        center {numpy.array} -- the center of the object, usually the same as of the bounding box
        scale {float} -- scale of the face

    Keyword Arguments:
        resolution {float} -- the size of the output cropped image (default: {256.0})

    Returns:
        numpy.array -- the ``uint8`` crop, resized to ``resolution``
    """
    ul = transform([1, 1], center, scale, resolution, True)
    br = transform([resolution, resolution], center, scale, resolution, True)

    # Canvas with the size of the crop window; any channel axis is preserved.
    # ``np.int32`` is used in both branches: the ``np.int`` alias was removed
    # in NumPy 1.24 and made the 2-D branch raise AttributeError.
    if image.ndim > 2:
        newDim = np.array([br[1] - ul[1], br[0] - ul[0],
                           image.shape[2]], dtype=np.int32)
    else:
        newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int32)
    newImg = np.zeros(newDim, dtype=np.uint8)

    ht = image.shape[0]
    wd = image.shape[1]
    # 1-based, inclusive ranges: where to paste in the canvas (new*) and what
    # to read from the source image (old*), both clipped to the image bounds.
    newX = np.array(
        [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
    newY = np.array(
        [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
    oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
    oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)

    # Slice only the two spatial axes so 2-D inputs work too; for 3-D inputs
    # this is identical to the former ``[..., :]`` three-index form.
    newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
           ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1]]
    newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
                        interpolation=cv2.INTER_LINEAR)
    return newImg
def get_preds_fromhm_batch(hm, centers=None, scales=None):
    """Obtain (x,y) coordinates given a set of N heatmaps. If the centers
    and the scales are provided the function will return the points also in
    the original coordinate frame.

    Each heatmap's arg-max is taken as the landmark position.  A peak that is
    not on the heatmap border is then shifted a quarter pixel towards the
    larger neighbour along each axis.  Works for any heatmap size: the
    interior test uses the actual heatmap dimensions.

    Arguments:
        hm {torch.tensor} -- the predicted heatmaps, a 4-D tensor whose last
            two axes are indexed as (row, column); must live on the CPU
            because ``Tensor.apply_`` is used

    Keyword Arguments:
        centers {torch.tensor} -- the centers of the bounding box, one per batch item (default: {None})
        scales {float} -- face scales, one per batch item (default: {None})

    Returns:
        tuple -- ``(preds, preds_orig)``; ``preds`` holds the (x, y) positions
        in heatmap coordinates, ``preds_orig`` the positions mapped through
        the inverse ``transform`` (all zeros when centers/scales are missing)
    """
    rows, cols = hm.size(2), hm.size(3)

    # Only the arg-max index is needed, not the peak value.
    _, idx = torch.max(
        hm.view(hm.size(0), hm.size(1), rows * cols), 2)
    idx += 1  # switch to 1-based flat indices
    preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
    # Flat index -> 1-based (x, y).  Both use the row length ``cols``, so
    # non-square heatmaps are decoded correctly.
    preds[..., 0].apply_(lambda x: (x - 1) % cols + 1)
    preds[..., 1].add_(-1).div_(cols).floor_().add_(1)

    for i in range(preds.size(0)):
        for j in range(preds.size(1)):
            hm_ = hm[i, j, :]
            pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
            # Refine only strictly interior peaks so pX +/- 1 and pY +/- 1
            # are valid indices whatever the heatmap size is.
            if 0 < pX < cols - 1 and 0 < pY < rows - 1:
                diff = torch.FloatTensor(
                    [hm_[pY, pX + 1] - hm_[pY, pX - 1],
                     hm_[pY + 1, pX] - hm_[pY - 1, pX]])
                preds[i, j].add_(diff.sign_().mul_(.25))

    preds.add_(-.5)

    preds_orig = torch.zeros(preds.size())
    if centers is not None and scales is not None:
        for i in range(hm.size(0)):
            for j in range(hm.size(1)):
                preds_orig[i, j] = transform(
                    preds[i, j], centers[i], scales[i], hm.size(2), True)

    return preds, preds_orig
def flip(tensor, is_label=False):
    """Mirror an image or a set of heatmaps along its last axis.

    Arguments:
        tensor {numpy.array or torch.tensor} -- the input image or heatmaps;
            numpy input is converted with ``torch.from_numpy``

    Keyword Arguments:
        is_label {bool} -- when True the input is treated as heatmaps and its
            channels are first reordered with ``shuffle_lr`` (default: {False})

    Returns:
        torch.tensor -- the flipped data
    """
    data = tensor if torch.is_tensor(tensor) else torch.from_numpy(tensor)
    last_axis = data.ndimension() - 1

    if is_label:
        # Swap the left/right channels before mirroring.
        data = shuffle_lr(data)

    return data.flip(last_axis)
+ """ + + # Define default user directory + userDir = os.getenv('FACEALIGNMENT_USERDIR', None) + if userDir is None: + userDir = os.path.expanduser('~') + if not os.path.isdir(userDir): # pragma: no cover + userDir = '/var/tmp' # issue #54 + + # Get system app data dir + path = None + if sys.platform.startswith('win'): + path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA') + path = (path2 or path1) if roaming else (path1 or path2) + elif sys.platform.startswith('darwin'): + path = os.path.join(userDir, 'Library', 'Application Support') + # On Linux and as fallback + if not (path and os.path.isdir(path)): + path = userDir + + # Maybe we should store things local to the executable (in case of a + # portable distro or a frozen application that wants to be portable) + prefix = sys.prefix + if getattr(sys, 'frozen', None): + prefix = os.path.abspath(os.path.dirname(sys.executable)) + for reldir in ('settings', '../settings'): + localpath = os.path.abspath(os.path.join(prefix, reldir)) + if os.path.isdir(localpath): # pragma: no cover + try: + open(os.path.join(localpath, 'test.write'), 'wb').close() + os.remove(os.path.join(localpath, 'test.write')) + except IOError: + pass # We cannot write in this directory + else: + path = localpath + break + + # Get path specific for this app + if appname: + if path == userDir: + appname = '.' 
class HParams:
    """Small attribute-style container for hyperparameters.

    Every keyword passed to the constructor is stored in the ``data`` dict and
    can be read back as an attribute (``hp.batch_size``).  Use ``set_hparam``
    to add or overwrite a value.
    """

    def __init__(self, **kwargs):
        self.data = dict(kwargs)

    def __getattr__(self, key):
        # Called only when normal attribute lookup fails.  Fetch the storage
        # dict through ``__dict__``: reading ``self.data`` here would re-enter
        # ``__getattr__`` forever when ``data`` is not set yet (an instance
        # created by ``copy`` / ``pickle`` before its state is restored) and
        # end in RecursionError instead of AttributeError.
        data = self.__dict__.get('data')
        if data is None or key not in data:
            raise AttributeError("'HParams' object has no attribute %s" % key)
        return data[key]

    def set_hparam(self, key, value):
        """Create or overwrite the hyperparameter ``key``."""
        self.data[key] = value
(Recommended: 12.5) + + # Mel and Linear spectrograms normalization/scaling and clipping + signal_normalization=True, + # Whether to normalize mel spectrograms to some predefined range (following below parameters) + allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True + symmetric_mels=True, + # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, + # faster and cleaner convergence) + max_abs_value=4., + # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not + # be too big to avoid gradient explosion, + # not too small for fast convergence) + # Contribution by @begeekmyfriend + # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude + # levels. Also allows for better G&L phase reconstruction) + preemphasize=True, # whether to apply filter + preemphasis=0.97, # filter coefficient. + + # Limits + min_level_db=-100, + ref_level_db=20, + fmin=55, + # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To + # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) + fmax=7600, # To be increased/reduced depending on data. + + ###################### Our training parameters ################################# + + + batch_size=16, + initial_learning_rate=1e-4, + nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs + num_workers=16, + checkpoint_interval=3000, + eval_interval=3000, + save_optimizer_state=True, + + syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. 
def hparams_debug_string():
    """Return a printable, name-sorted listing of the module-level ``hparams``.

    An entry named ``"sentences"`` is left out of the listing.
    """
    # HParams keeps its settings in the plain dict ``data`` and defines no
    # ``values()`` method, so ``hparams.values()`` raised AttributeError.
    values = hparams.data
    hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
    return "Hyperparameters:\n" + "\n".join(hp)
class Conv2d(nn.Module):
    """Convolution + batch norm + ReLU, with an optional identity skip.

    With ``residual=True`` the block input is added to the batch-norm output
    before the activation, so input and output shapes must match in that
    mode.  Extra positional/keyword arguments are forwarded to ``nn.Module``.
    """

    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        convolution = nn.Conv2d(cin, cout, kernel_size, stride, padding)
        normalisation = nn.BatchNorm2d(cout)
        self.conv_block = nn.Sequential(convolution, normalisation)
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        """Apply the block to ``x`` and return the activated features."""
        features = self.conv_block(x)
        if not self.residual:
            return self.act(features)
        features += x
        return self.act(features)
Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3), + + Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(64, 128, kernel_size=3, stride=2, padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(128, 256, kernel_size=3, stride=2, padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(256, 512, kernel_size=3, stride=2, padding=1), + Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(512, 512, kernel_size=3, stride=2, padding=1), + Conv2d(512, 512, kernel_size=3, stride=1, padding=0), + Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) + + self.audio_encoder = nn.Sequential( + Conv2d(1, 32, kernel_size=3, stride=1, padding=1), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(64, 128, kernel_size=3, stride=3, padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(256, 512, kernel_size=3, stride=1, padding=0), + 
class Wav2Lip(nn.Module):
    """Encoder/decoder generator driven by an audio embedding.

    The face input goes through ``face_encoder_blocks``; the output of every
    encoder block is kept.  The audio input is reduced to an embedding by
    ``audio_encoder``, which is then expanded by ``face_decoder_blocks``.
    After each decoder block the most recent unused encoder output is
    concatenated on the channel axis (skip connection).  ``output_block``
    maps the final features to 3 channels passed through a sigmoid.
    """

    def __init__(self):
        super(Wav2Lip, self).__init__()

        # Face encoder.  The trailing comments give the spatial size after
        # each block for a 96x96 input.  The order of these blocks must mirror
        # the decoder below, which consumes their outputs last-to-first.
        self.face_encoder_blocks = nn.ModuleList([
            nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96

            nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),

            nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])

        # Audio encoder: single-channel input reduced to a 512-channel embedding.
        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)

        # Face decoder.  From the second block on, the input channel count is
        # the previous block's output channels plus the channels of the encoder
        # feature concatenated in forward(): 512+512, 512+512, 512+256,
        # 384+128, 256+64, 128+32.
        self.face_decoder_blocks = nn.ModuleList([
            nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),),

            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),

            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6

            nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
            Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12

            nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24

            nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48

            nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96

        # 80 input channels = 64 from the last decoder block + 16 from the
        # first encoder block.
        self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid())

    def forward(self, audio_sequences, face_sequences):
        """Generate output frames from audio and face inputs.

        When ``face_sequences`` has more than 4 dimensions, the time axis
        (dim 1 of the audio, dim 2 of the faces) is folded into the batch axis
        before the networks run and unfolded again at the end, giving a
        (B, C, T, H, W) result.  Otherwise the output of ``output_block`` is
        returned as is.
        """
        # audio_sequences = (B, T, 1, 80, 16)
        B = audio_sequences.size(0)

        input_dim_size = len(face_sequences.size())
        if input_dim_size > 4:
            # Stack the T time steps along the batch axis.
            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
            face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)

        audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1

        # Keep every encoder block output for the skip connections.
        feats = []
        x = face_sequences
        for f in self.face_encoder_blocks:
            x = f(x)
            feats.append(x)

        x = audio_embedding
        for f in self.face_decoder_blocks:
            x = f(x)
            try:
                # Skip connection: newest unused encoder feature, channel-wise.
                x = torch.cat((x, feats[-1]), dim=1)
            except Exception as e:
                # Print both shapes to help diagnose a mismatch, then re-raise.
                print(x.size())
                print(feats[-1].size())
                raise e

            feats.pop()

        x = self.output_block(x)

        if input_dim_size > 4:
            # Undo the batch folding: chunks of B samples, one per time step.
            x = torch.split(x, B, dim=0) # [(B, C, H, W)]
            outputs = torch.stack(x, dim=2) # (B, C, T, H, W)

        else:
            outputs = x

        return outputs
Wav2Lip_disc_qual(nn.Module): + def __init__(self): + super(Wav2Lip_disc_qual, self).__init__() + + self.face_encoder_blocks = nn.ModuleList([ + nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)), # 48,96 + + nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2), # 48,48 + nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)), + + nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2), # 24,24 + nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)), + + nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # 12,12 + nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)), + + nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6,6 + nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)), + + nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 3,3 + nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),), + + nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1 + nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),]) + + self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid()) + self.label_noise = .0 + + def get_lower_half(self, face_sequences): + return face_sequences[:, :, face_sequences.size(2)//2:] + + def to_2d(self, face_sequences): + B = face_sequences.size(0) + face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0) + return face_sequences + + def perceptual_forward(self, false_face_sequences): + false_face_sequences = self.to_2d(false_face_sequences) + false_face_sequences = self.get_lower_half(false_face_sequences) + + false_feats = false_face_sequences + for f in self.face_encoder_blocks: + false_feats = f(false_feats) + + false_pred_loss = F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1), + torch.ones((len(false_feats), 1)).cuda()) + + 
return false_pred_loss + + def forward(self, face_sequences): + face_sequences = self.to_2d(face_sequences) + face_sequences = self.get_lower_half(face_sequences) + + x = face_sequences + for f in self.face_encoder_blocks: + x = f(x) + + return self.binary_pred(x).view(len(x), -1) diff --git a/wav2lip/wav2lip_gan.pth b/wav2lip/wav2lip_gan.pth new file mode 100644 index 0000000000000000000000000000000000000000..0b8907735b1aef6adb297f2f6dcbcf8432823de4 --- /dev/null +++ b/wav2lip/wav2lip_gan.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca9ab7b7b812c0e80a6e70a5977c545a1e8a365a6c49d5e533023c034d7ac3d8 +size 435801865 diff --git a/weights/RealESRGAN_x2.pth b/weights/RealESRGAN_x2.pth new file mode 100644 index 0000000000000000000000000000000000000000..313b87ab9359a04b0f450695b1a01a88edd4ac95 --- /dev/null +++ b/weights/RealESRGAN_x2.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c830d067d54fc767b9543a8432f36d91bc2de313584e8bbfe4ac26a47339e899 +size 67061725 diff --git a/weights/RealESRGAN_x4.pth b/weights/RealESRGAN_x4.pth new file mode 100644 index 0000000000000000000000000000000000000000..0271c7326bedc41232e15142c569e09920451045 --- /dev/null +++ b/weights/RealESRGAN_x4.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa00f09ad753d88576b21ed977e97d634976377031b178acc3b5b238df463400 +size 67040989