diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..f9ba8cf65f3e3104dd061c178066ec8247811f33 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,9 @@ +# Microsoft Open Source Code of Conduct + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). + +Resources: + +- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) +- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..8764e0011f8e0b937674005354ca957317c23fd4 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,43 @@ +FROM nvidia/cuda:11.1-base-ubuntu20.04 + +RUN apt update && DEBIAN_FRONTEND=noninteractive apt install git bzip2 wget unzip python3-pip python3-dev cmake libgl1-mesa-dev python-is-python3 libgtk2.0-dev -yq +ADD . /app +WORKDIR /app +RUN cd Face_Enhancement/models/networks/ &&\ + git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch &&\ + cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . &&\ + cd ../../../ + +RUN cd Global/detection_models &&\ + git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch &&\ + cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . &&\ + cd ../../ + +RUN cd Face_Detection/ &&\ + wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 &&\ + bzip2 -d shape_predictor_68_face_landmarks.dat.bz2 &&\ + cd ../ + +RUN cd Face_Enhancement/ &&\ + wget https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Face_Enhancement/checkpoints.zip &&\ + unzip checkpoints.zip &&\ + cd ../ &&\ + cd Global/ &&\ + wget https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Global/checkpoints.zip &&\ + unzip checkpoints.zip &&\ + rm -f checkpoints.zip &&\ + cd ../ + +RUN pip3 install numpy + +RUN pip3 install dlib + +RUN pip3 install -r requirements.txt + +RUN git clone https://github.com/NVlabs/SPADE.git + +RUN cd SPADE/ && pip3 install -r requirements.txt + +RUN cd .. + +CMD ["python3", "run.py"] diff --git a/Face_Detection/align_warp_back_multiple_dlib.py b/Face_Detection/align_warp_back_multiple_dlib.py new file mode 100644 index 0000000000000000000000000000000000000000..4b82139e4a81201b16fdfe56bc1cdb2b97bac398 --- /dev/null +++ b/Face_Detection/align_warp_back_multiple_dlib.py @@ -0,0 +1,437 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import numpy as np +import skimage.io as io + +# from face_sdk import FaceDetection +import matplotlib.pyplot as plt +from matplotlib.patches import Rectangle +from skimage.transform import SimilarityTransform +from skimage.transform import warp +from PIL import Image, ImageFilter +import torch.nn.functional as F +import torchvision as tv +import torchvision.utils as vutils +import time +import cv2 +import os +from skimage import img_as_ubyte +import json +import argparse +import dlib + + +def calculate_cdf(histogram): + """ + This method calculates the cumulative distribution function + :param array histogram: The values of the histogram + :return: normalized_cdf: The normalized cumulative distribution function + :rtype: array + """ + # Get the cumulative sum of the elements + cdf = histogram.cumsum() + + # Normalize the cdf + normalized_cdf = cdf / float(cdf.max()) + + return normalized_cdf + + +def calculate_lookup(src_cdf, ref_cdf): + """ + This method creates the lookup table + :param array src_cdf: The cdf for the source image + :param array ref_cdf: The cdf for the reference image + :return: lookup_table: The lookup table + :rtype: array + """ + lookup_table = np.zeros(256) + lookup_val = 0 + for src_pixel_val in range(len(src_cdf)): + lookup_val + for ref_pixel_val in range(len(ref_cdf)): + if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]: + lookup_val = ref_pixel_val + break + lookup_table[src_pixel_val] = lookup_val + return lookup_table + + +def match_histograms(src_image, ref_image): + """ + This method matches the source image histogram to the + reference signal + :param image src_image: The original source image + :param image ref_image: The reference image + :return: image_after_matching + :rtype: image (array) + """ + # Split the images into the different color channels + # b means blue, g means green and r means red + src_b, src_g, src_r = cv2.split(src_image) + ref_b, ref_g, ref_r = cv2.split(ref_image) + + # Compute the b, g, and r histograms separately + # The flatten() Numpy method returns a copy of the array c + # collapsed into one dimension. + src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256]) + src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256]) + src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256]) + ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256]) + ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256]) + ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256]) + + # Compute the normalized cdf for the source and reference image + src_cdf_blue = calculate_cdf(src_hist_blue) + src_cdf_green = calculate_cdf(src_hist_green) + src_cdf_red = calculate_cdf(src_hist_red) + ref_cdf_blue = calculate_cdf(ref_hist_blue) + ref_cdf_green = calculate_cdf(ref_hist_green) + ref_cdf_red = calculate_cdf(ref_hist_red) + + # Make a separate lookup table for each color + blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue) + green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green) + red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red) + + # Use the lookup function to transform the colors of the original + # source image + blue_after_transform = cv2.LUT(src_b, blue_lookup_table) + green_after_transform = cv2.LUT(src_g, green_lookup_table) + red_after_transform = cv2.LUT(src_r, red_lookup_table) + + # Put the image back together + image_after_matching = cv2.merge([blue_after_transform, green_after_transform, red_after_transform]) + image_after_matching = cv2.convertScaleAbs(image_after_matching) + + return image_after_matching + + +def _standard_face_pts(): + pts = ( + np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0 + - 1.0 + ) + + return np.reshape(pts, (5, 2)) + + +def _origin_face_pts(): + pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) + + return np.reshape(pts, (5, 2)) + + +def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0): + + std_pts = _standard_face_pts() # [-1,1] + target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0 + + # print(target_pts) + + h, w, c = img.shape + if normalize == True: + landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0 + landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0 + + # print(landmark) + + affine = SimilarityTransform() + + affine.estimate(target_pts, landmark) + + return affine + + +def compute_inverse_transformation_matrix(img, landmark, normalize, target_face_scale=1.0): + + std_pts = _standard_face_pts() # [-1,1] + target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0 + + # print(target_pts) + + h, w, c = img.shape + if normalize == True: + landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0 + landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0 + + # print(landmark) + + affine = SimilarityTransform() + + affine.estimate(landmark, target_pts) + + return affine + + +def show_detection(image, box, landmark): + plt.imshow(image) + print(box[2] - box[0]) + plt.gca().add_patch( + Rectangle( + (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none" + ) + ) + plt.scatter(landmark[0][0], landmark[0][1]) + plt.scatter(landmark[1][0], landmark[1][1]) + plt.scatter(landmark[2][0], landmark[2][1]) + plt.scatter(landmark[3][0], landmark[3][1]) + plt.scatter(landmark[4][0], landmark[4][1]) + plt.show() + + +def affine2theta(affine, input_w, input_h, target_w, target_h): + # param = np.linalg.inv(affine) + param = affine + theta = np.zeros([2, 3]) + theta[0, 0] = param[0, 0] * input_h / target_h + theta[0, 1] = param[0, 1] * input_w / target_h + theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1 + theta[1, 0] = param[1, 0] * input_h / target_w + theta[1, 1] = param[1, 1] * input_w / target_w + theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1 + return theta + + +def blur_blending(im1, im2, mask): + + mask *= 255.0 + + kernel = np.ones((10, 10), np.uint8) + mask = cv2.erode(mask, kernel, iterations=1) + + mask = Image.fromarray(mask.astype("uint8")).convert("L") + im1 = Image.fromarray(im1.astype("uint8")) + im2 = Image.fromarray(im2.astype("uint8")) + + mask_blur = mask.filter(ImageFilter.GaussianBlur(20)) + im = Image.composite(im1, im2, mask) + + im = Image.composite(im, im2, mask_blur) + + return np.array(im) / 255.0 + + +def blur_blending_cv2(im1, im2, mask): + + mask *= 255.0 + + kernel = np.ones((9, 9), np.uint8) + mask = cv2.erode(mask, kernel, iterations=3) + + mask_blur = cv2.GaussianBlur(mask, (25, 25), 0) + mask_blur /= 255.0 + + im = im1 * mask_blur + (1 - mask_blur) * im2 + + im /= 255.0 + im = np.clip(im, 0.0, 1.0) + + return im + + +# def Poisson_blending(im1,im2,mask): + + +# Image.composite( +def Poisson_blending(im1, im2, mask): + + # mask=1-mask + mask *= 255 + kernel = np.ones((10, 10), np.uint8) + mask = cv2.erode(mask, kernel, iterations=1) + mask /= 255 + mask = 1 - mask + mask *= 255 + + mask = mask[:, :, 0] + width, height, channels = im1.shape + center = (int(height / 2), int(width / 2)) + result = cv2.seamlessClone( + im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.MIXED_CLONE + ) + + return result / 255.0 + + +def Poisson_B(im1, im2, mask, center): + + mask *= 255 + + result = cv2.seamlessClone( + im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.NORMAL_CLONE + ) + + return result / 255 + + +def seamless_clone(old_face, new_face, raw_mask): + + height, width, _ = old_face.shape + height = height // 2 + width = width // 2 + + y_indices, x_indices, _ = np.nonzero(raw_mask) + y_crop = slice(np.min(y_indices), np.max(y_indices)) + x_crop = slice(np.min(x_indices), np.max(x_indices)) + y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height)) + x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width)) + + insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8") + insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8") + insertion_mask[insertion_mask != 0] = 255 + prior = np.rint(np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant")).astype( + "uint8" + ) + # if np.sum(insertion_mask) == 0: + n_mask = insertion_mask[1:-1, 1:-1, :] + n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0) + print(n_mask.shape) + x, y, w, h = cv2.boundingRect(n_mask[:, :, 0]) + if w < 4 or h < 4: + blended = prior + else: + blended = cv2.seamlessClone( + insertion, # pylint: disable=no-member + prior, + insertion_mask, + (x_center, y_center), + cv2.NORMAL_CLONE, + ) # pylint: disable=no-member + + blended = blended[height:-height, width:-width] + + return blended.astype("float32") / 255.0 + + +def get_landmark(face_landmarks, id): + part = face_landmarks.part(id) + x = part.x + y = part.y + + return (x, y) + + +def search(face_landmarks): + + x1, y1 = get_landmark(face_landmarks, 36) + x2, y2 = get_landmark(face_landmarks, 39) + x3, y3 = get_landmark(face_landmarks, 42) + x4, y4 = get_landmark(face_landmarks, 45) + + x_nose, y_nose = get_landmark(face_landmarks, 30) + + x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48) + x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54) + + x_left_eye = int((x1 + x2) / 2) + y_left_eye = int((y1 + y2) / 2) + x_right_eye = int((x3 + x4) / 2) + y_right_eye = int((y3 + y4) / 2) + + results = np.array( + [ + [x_left_eye, y_left_eye], + [x_right_eye, y_right_eye], + [x_nose, y_nose], + [x_left_mouth, y_left_mouth], + [x_right_mouth, y_right_mouth], + ] + ) + + return results + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("--origin_url", type=str, default="./", help="origin images") + parser.add_argument("--replace_url", type=str, default="./", help="restored faces") + parser.add_argument("--save_url", type=str, default="./save") + opts = parser.parse_args() + + origin_url = opts.origin_url + replace_url = opts.replace_url + save_url = opts.save_url + + if not os.path.exists(save_url): + os.makedirs(save_url) + + face_detector = dlib.get_frontal_face_detector() + landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") + + count = 0 + + for x in os.listdir(origin_url): + img_url = os.path.join(origin_url, x) + pil_img = Image.open(img_url).convert("RGB") + + origin_width, origin_height = pil_img.size + image = np.array(pil_img) + + start = time.time() + faces = face_detector(image) + done = time.time() + + if len(faces) == 0: + print("Warning: There is no face in %s" % (x)) + continue + + blended = image + for face_id in range(len(faces)): + + current_face = faces[face_id] + face_landmarks = landmark_locator(image, current_face) + current_fl = search(face_landmarks) + + forward_mask = np.ones_like(image).astype("uint8") + affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3) + aligned_face = warp(image, affine, output_shape=(256, 256, 3), preserve_range=True) + forward_mask = warp( + forward_mask, affine, output_shape=(256, 256, 3), order=0, preserve_range=True + ) + + affine_inverse = affine.inverse + cur_face = aligned_face + if replace_url != "": + + face_name = x[:-4] + "_" + str(face_id + 1) + ".png" + cur_url = os.path.join(replace_url, face_name) + restored_face = Image.open(cur_url).convert("RGB") + restored_face = np.array(restored_face) + cur_face = restored_face + + ## Histogram Color matching + A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR) + B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR) + B = match_histograms(B, A) + cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB) + + warped_back = warp( + cur_face, + affine_inverse, + output_shape=(origin_height, origin_width, 3), + order=3, + preserve_range=True, + ) + + backward_mask = warp( + forward_mask, + affine_inverse, + output_shape=(origin_height, origin_width, 3), + order=0, + preserve_range=True, + ) ## Nearest neighbour + + blended = blur_blending_cv2(warped_back, blended, backward_mask) + blended *= 255.0 + + io.imsave(os.path.join(save_url, x), img_as_ubyte(blended / 255.0)) + + count += 1 + + if count % 1000 == 0: + print("%d have finished ..." % (count)) + diff --git a/Face_Detection/align_warp_back_multiple_dlib_HR.py b/Face_Detection/align_warp_back_multiple_dlib_HR.py new file mode 100644 index 0000000000000000000000000000000000000000..f3711c968ebeba22f3872b8074b7c89f55a634a1 --- /dev/null +++ b/Face_Detection/align_warp_back_multiple_dlib_HR.py @@ -0,0 +1,437 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import numpy as np +import skimage.io as io + +# from face_sdk import FaceDetection +import matplotlib.pyplot as plt +from matplotlib.patches import Rectangle +from skimage.transform import SimilarityTransform +from skimage.transform import warp +from PIL import Image, ImageFilter +import torch.nn.functional as F +import torchvision as tv +import torchvision.utils as vutils +import time +import cv2 +import os +from skimage import img_as_ubyte +import json +import argparse +import dlib + + +def calculate_cdf(histogram): + """ + This method calculates the cumulative distribution function + :param array histogram: The values of the histogram + :return: normalized_cdf: The normalized cumulative distribution function + :rtype: array + """ + # Get the cumulative sum of the elements + cdf = histogram.cumsum() + + # Normalize the cdf + normalized_cdf = cdf / float(cdf.max()) + + return normalized_cdf + + +def calculate_lookup(src_cdf, ref_cdf): + """ + This method creates the lookup table + :param array src_cdf: The cdf for the source image + :param array ref_cdf: The cdf for the reference image + :return: lookup_table: The lookup table + :rtype: array + """ + lookup_table = np.zeros(256) + lookup_val = 0 + for src_pixel_val in range(len(src_cdf)): + lookup_val + for ref_pixel_val in range(len(ref_cdf)): + if ref_cdf[ref_pixel_val] >= src_cdf[src_pixel_val]: + lookup_val = ref_pixel_val + break + lookup_table[src_pixel_val] = lookup_val + return lookup_table + + +def match_histograms(src_image, ref_image): + """ + This method matches the source image histogram to the + reference signal + :param image src_image: The original source image + :param image ref_image: The reference image + :return: image_after_matching + :rtype: image (array) + """ + # Split the images into the different color channels + # b means blue, g means green and r means red + src_b, src_g, src_r = cv2.split(src_image) + ref_b, ref_g, ref_r = cv2.split(ref_image) + + # Compute the b, g, and r histograms separately + # The flatten() Numpy method returns a copy of the array c + # collapsed into one dimension. + src_hist_blue, bin_0 = np.histogram(src_b.flatten(), 256, [0, 256]) + src_hist_green, bin_1 = np.histogram(src_g.flatten(), 256, [0, 256]) + src_hist_red, bin_2 = np.histogram(src_r.flatten(), 256, [0, 256]) + ref_hist_blue, bin_3 = np.histogram(ref_b.flatten(), 256, [0, 256]) + ref_hist_green, bin_4 = np.histogram(ref_g.flatten(), 256, [0, 256]) + ref_hist_red, bin_5 = np.histogram(ref_r.flatten(), 256, [0, 256]) + + # Compute the normalized cdf for the source and reference image + src_cdf_blue = calculate_cdf(src_hist_blue) + src_cdf_green = calculate_cdf(src_hist_green) + src_cdf_red = calculate_cdf(src_hist_red) + ref_cdf_blue = calculate_cdf(ref_hist_blue) + ref_cdf_green = calculate_cdf(ref_hist_green) + ref_cdf_red = calculate_cdf(ref_hist_red) + + # Make a separate lookup table for each color + blue_lookup_table = calculate_lookup(src_cdf_blue, ref_cdf_blue) + green_lookup_table = calculate_lookup(src_cdf_green, ref_cdf_green) + red_lookup_table = calculate_lookup(src_cdf_red, ref_cdf_red) + + # Use the lookup function to transform the colors of the original + # source image + blue_after_transform = cv2.LUT(src_b, blue_lookup_table) + green_after_transform = cv2.LUT(src_g, green_lookup_table) + red_after_transform = cv2.LUT(src_r, red_lookup_table) + + # Put the image back together + image_after_matching = cv2.merge([blue_after_transform, green_after_transform, red_after_transform]) + image_after_matching = cv2.convertScaleAbs(image_after_matching) + + return image_after_matching + + +def _standard_face_pts(): + pts = ( + np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0 + - 1.0 + ) + + return np.reshape(pts, (5, 2)) + + +def _origin_face_pts(): + pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) + + return np.reshape(pts, (5, 2)) + + +def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0): + + std_pts = _standard_face_pts() # [-1,1] + target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0 + + # print(target_pts) + + h, w, c = img.shape + if normalize == True: + landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0 + landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0 + + # print(landmark) + + affine = SimilarityTransform() + + affine.estimate(target_pts, landmark) + + return affine + + +def compute_inverse_transformation_matrix(img, landmark, normalize, target_face_scale=1.0): + + std_pts = _standard_face_pts() # [-1,1] + target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0 + + # print(target_pts) + + h, w, c = img.shape + if normalize == True: + landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0 + landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0 + + # print(landmark) + + affine = SimilarityTransform() + + affine.estimate(landmark, target_pts) + + return affine + + +def show_detection(image, box, landmark): + plt.imshow(image) + print(box[2] - box[0]) + plt.gca().add_patch( + Rectangle( + (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none" + ) + ) + plt.scatter(landmark[0][0], landmark[0][1]) + plt.scatter(landmark[1][0], landmark[1][1]) + plt.scatter(landmark[2][0], landmark[2][1]) + plt.scatter(landmark[3][0], landmark[3][1]) + plt.scatter(landmark[4][0], landmark[4][1]) + plt.show() + + +def affine2theta(affine, input_w, input_h, target_w, target_h): + # param = np.linalg.inv(affine) + param = affine + theta = np.zeros([2, 3]) + theta[0, 0] = param[0, 0] * input_h / target_h + theta[0, 1] = param[0, 1] * input_w / target_h + theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1 + theta[1, 0] = param[1, 0] * input_h / target_w + theta[1, 1] = param[1, 1] * input_w / target_w + theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1 + return theta + + +def blur_blending(im1, im2, mask): + + mask *= 255.0 + + kernel = np.ones((10, 10), np.uint8) + mask = cv2.erode(mask, kernel, iterations=1) + + mask = Image.fromarray(mask.astype("uint8")).convert("L") + im1 = Image.fromarray(im1.astype("uint8")) + im2 = Image.fromarray(im2.astype("uint8")) + + mask_blur = mask.filter(ImageFilter.GaussianBlur(20)) + im = Image.composite(im1, im2, mask) + + im = Image.composite(im, im2, mask_blur) + + return np.array(im) / 255.0 + + +def blur_blending_cv2(im1, im2, mask): + + mask *= 255.0 + + kernel = np.ones((9, 9), np.uint8) + mask = cv2.erode(mask, kernel, iterations=3) + + mask_blur = cv2.GaussianBlur(mask, (25, 25), 0) + mask_blur /= 255.0 + + im = im1 * mask_blur + (1 - mask_blur) * im2 + + im /= 255.0 + im = np.clip(im, 0.0, 1.0) + + return im + + +# def Poisson_blending(im1,im2,mask): + + +# Image.composite( +def Poisson_blending(im1, im2, mask): + + # mask=1-mask + mask *= 255 + kernel = np.ones((10, 10), np.uint8) + mask = cv2.erode(mask, kernel, iterations=1) + mask /= 255 + mask = 1 - mask + mask *= 255 + + mask = mask[:, :, 0] + width, height, channels = im1.shape + center = (int(height / 2), int(width / 2)) + result = cv2.seamlessClone( + im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.MIXED_CLONE + ) + + return result / 255.0 + + +def Poisson_B(im1, im2, mask, center): + + mask *= 255 + + result = cv2.seamlessClone( + im2.astype("uint8"), im1.astype("uint8"), mask.astype("uint8"), center, cv2.NORMAL_CLONE + ) + + return result / 255 + + +def seamless_clone(old_face, new_face, raw_mask): + + height, width, _ = old_face.shape + height = height // 2 + width = width // 2 + + y_indices, x_indices, _ = np.nonzero(raw_mask) + y_crop = slice(np.min(y_indices), np.max(y_indices)) + x_crop = slice(np.min(x_indices), np.max(x_indices)) + y_center = int(np.rint((np.max(y_indices) + np.min(y_indices)) / 2 + height)) + x_center = int(np.rint((np.max(x_indices) + np.min(x_indices)) / 2 + width)) + + insertion = np.rint(new_face[y_crop, x_crop] * 255.0).astype("uint8") + insertion_mask = np.rint(raw_mask[y_crop, x_crop] * 255.0).astype("uint8") + insertion_mask[insertion_mask != 0] = 255 + prior = np.rint(np.pad(old_face * 255.0, ((height, height), (width, width), (0, 0)), "constant")).astype( + "uint8" + ) + # if np.sum(insertion_mask) == 0: + n_mask = insertion_mask[1:-1, 1:-1, :] + n_mask = cv2.copyMakeBorder(n_mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, 0) + print(n_mask.shape) + x, y, w, h = cv2.boundingRect(n_mask[:, :, 0]) + if w < 4 or h < 4: + blended = prior + else: + blended = cv2.seamlessClone( + insertion, # pylint: disable=no-member + prior, + insertion_mask, + (x_center, y_center), + cv2.NORMAL_CLONE, + ) # pylint: disable=no-member + + blended = blended[height:-height, width:-width] + + return blended.astype("float32") / 255.0 + + +def get_landmark(face_landmarks, id): + part = face_landmarks.part(id) + x = part.x + y = part.y + + return (x, y) + + +def search(face_landmarks): + + x1, y1 = get_landmark(face_landmarks, 36) + x2, y2 = get_landmark(face_landmarks, 39) + x3, y3 = get_landmark(face_landmarks, 42) + x4, y4 = get_landmark(face_landmarks, 45) + + x_nose, y_nose = get_landmark(face_landmarks, 30) + + x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48) + x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54) + + x_left_eye = int((x1 + x2) / 2) + y_left_eye = int((y1 + y2) / 2) + x_right_eye = int((x3 + x4) / 2) + y_right_eye = int((y3 + y4) / 2) + + results = np.array( + [ + [x_left_eye, y_left_eye], + [x_right_eye, y_right_eye], + [x_nose, y_nose], + [x_left_mouth, y_left_mouth], + [x_right_mouth, y_right_mouth], + ] + ) + + return results + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("--origin_url", type=str, default="./", help="origin images") + parser.add_argument("--replace_url", type=str, default="./", help="restored faces") + parser.add_argument("--save_url", type=str, default="./save") + opts = parser.parse_args() + + origin_url = opts.origin_url + replace_url = opts.replace_url + save_url = opts.save_url + + if not os.path.exists(save_url): + os.makedirs(save_url) + + face_detector = dlib.get_frontal_face_detector() + landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") + + count = 0 + + for x in os.listdir(origin_url): + img_url = os.path.join(origin_url, x) + pil_img = Image.open(img_url).convert("RGB") + + origin_width, origin_height = pil_img.size + image = np.array(pil_img) + + start = time.time() + faces = face_detector(image) + done = time.time() + + if len(faces) == 0: + print("Warning: There is no face in %s" % (x)) + continue + + blended = image + for face_id in range(len(faces)): + + current_face = faces[face_id] + face_landmarks = landmark_locator(image, current_face) + current_fl = search(face_landmarks) + + forward_mask = np.ones_like(image).astype("uint8") + affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3) + aligned_face = warp(image, affine, output_shape=(512, 512, 3), preserve_range=True) + forward_mask = warp( + forward_mask, affine, output_shape=(512, 512, 3), order=0, preserve_range=True + ) + + affine_inverse = affine.inverse + cur_face = aligned_face + if replace_url != "": + + face_name = x[:-4] + "_" + str(face_id + 1) + ".png" + cur_url = os.path.join(replace_url, face_name) + restored_face = Image.open(cur_url).convert("RGB") + restored_face = np.array(restored_face) + cur_face = restored_face + + ## Histogram Color matching + A = cv2.cvtColor(aligned_face.astype("uint8"), cv2.COLOR_RGB2BGR) + B = cv2.cvtColor(cur_face.astype("uint8"), cv2.COLOR_RGB2BGR) + B = match_histograms(B, A) + cur_face = cv2.cvtColor(B.astype("uint8"), cv2.COLOR_BGR2RGB) + + warped_back = warp( + cur_face, + affine_inverse, + output_shape=(origin_height, origin_width, 3), + order=3, + preserve_range=True, + ) + + backward_mask = warp( + forward_mask, + affine_inverse, + output_shape=(origin_height, origin_width, 3), + order=0, + preserve_range=True, + ) ## Nearest neighbour + + blended = blur_blending_cv2(warped_back, blended, backward_mask) + blended *= 255.0 + + io.imsave(os.path.join(save_url, x), img_as_ubyte(blended / 255.0)) + + count += 1 + + if count % 1000 == 0: + print("%d have finished ..." % (count)) + diff --git a/Face_Detection/detect_all_dlib.py b/Face_Detection/detect_all_dlib.py new file mode 100644 index 0000000000000000000000000000000000000000..081b4c185e75f949dc6e2cf9ce55db78244452b6 --- /dev/null +++ b/Face_Detection/detect_all_dlib.py @@ -0,0 +1,184 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import numpy as np +import skimage.io as io + +# from FaceSDK.face_sdk import FaceDetection +# from face_sdk import FaceDetection +import matplotlib.pyplot as plt +from matplotlib.patches import Rectangle +from skimage.transform import SimilarityTransform +from skimage.transform import warp +from PIL import Image +import torch.nn.functional as F +import torchvision as tv +import torchvision.utils as vutils +import time +import cv2 +import os +from skimage import img_as_ubyte +import json +import argparse +import dlib + + +def _standard_face_pts(): + pts = ( + np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0 + - 1.0 + ) + + return np.reshape(pts, (5, 2)) + + +def _origin_face_pts(): + pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) + + return np.reshape(pts, (5, 2)) + + +def get_landmark(face_landmarks, id): + part = face_landmarks.part(id) + x = part.x + y = part.y + + return (x, y) + + +def search(face_landmarks): + + x1, y1 = get_landmark(face_landmarks, 36) + x2, y2 = get_landmark(face_landmarks, 39) + x3, y3 = get_landmark(face_landmarks, 42) + x4, y4 = get_landmark(face_landmarks, 45) + + x_nose, y_nose = get_landmark(face_landmarks, 30) + + x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48) + x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54) + + x_left_eye = int((x1 + x2) / 2) + y_left_eye = int((y1 + y2) / 2) + x_right_eye = int((x3 + x4) / 2) + y_right_eye = int((y3 + y4) / 2) + + results = np.array( + [ + [x_left_eye, y_left_eye], + [x_right_eye, y_right_eye], + [x_nose, y_nose], + [x_left_mouth, y_left_mouth], + [x_right_mouth, y_right_mouth], + ] + ) + + return results + + +def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0): + + std_pts = _standard_face_pts() # [-1,1] + target_pts = (std_pts * target_face_scale + 1) / 2 * 256.0 + + # print(target_pts) + + h, w, c = img.shape + if normalize == True: + landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0 + landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0 + + # print(landmark) + + affine = SimilarityTransform() + + affine.estimate(target_pts, landmark) + + return affine.params + + +def show_detection(image, box, landmark): + plt.imshow(image) + print(box[2] - box[0]) + plt.gca().add_patch( + Rectangle( + (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none" + ) + ) + plt.scatter(landmark[0][0], landmark[0][1]) + plt.scatter(landmark[1][0], landmark[1][1]) + plt.scatter(landmark[2][0], landmark[2][1]) + plt.scatter(landmark[3][0], landmark[3][1]) + plt.scatter(landmark[4][0], landmark[4][1]) + plt.show() + + +def affine2theta(affine, input_w, input_h, target_w, target_h): + # param = np.linalg.inv(affine) + param = affine + theta = np.zeros([2, 3]) + theta[0, 0] = param[0, 0] * input_h / target_h + theta[0, 1] = param[0, 1] * input_w / target_h + theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1 + theta[1, 0] = param[1, 0] * input_h / target_w + theta[1, 1] = param[1, 1] * input_w / target_w + theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1 + return theta + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("--url", type=str, default="/home/jingliao/ziyuwan/celebrities", help="input") + parser.add_argument( + "--save_url", type=str, default="/home/jingliao/ziyuwan/celebrities_detected_face_reid", help="output" + ) + opts = parser.parse_args() + + url = opts.url + save_url = opts.save_url + + ### If the origin url is None, then we don't need to reid the origin image + + os.makedirs(url, exist_ok=True) + os.makedirs(save_url, exist_ok=True) + + face_detector = dlib.get_frontal_face_detector() + landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") + + count = 0 + + map_id = {} + for x in os.listdir(url): + img_url = os.path.join(url, x) + pil_img = Image.open(img_url).convert("RGB") + + image = np.array(pil_img) + + start = time.time() + faces = face_detector(image) + done = time.time() + + if len(faces) == 0: + print("Warning: There is no face in %s" % (x)) + continue + + print(len(faces)) + + if len(faces) > 0: + for face_id in range(len(faces)): + current_face = faces[face_id] + face_landmarks = landmark_locator(image, current_face) + current_fl = search(face_landmarks) + + affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3) + aligned_face = warp(image, affine, output_shape=(256, 256, 3)) + img_name = x[:-4] + "_" + str(face_id + 1) + io.imsave(os.path.join(save_url, img_name + ".png"), img_as_ubyte(aligned_face)) + + count += 1 + + if count % 1000 == 0: + print("%d have finished ..." % (count)) + diff --git a/Face_Detection/detect_all_dlib_HR.py b/Face_Detection/detect_all_dlib_HR.py new file mode 100644 index 0000000000000000000000000000000000000000..f52e149bf2a9f612f4fbaca83f712da11fae0db5 --- /dev/null +++ b/Face_Detection/detect_all_dlib_HR.py @@ -0,0 +1,184 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import numpy as np +import skimage.io as io + +# from FaceSDK.face_sdk import FaceDetection +# from face_sdk import FaceDetection +import matplotlib.pyplot as plt +from matplotlib.patches import Rectangle +from skimage.transform import SimilarityTransform +from skimage.transform import warp +from PIL import Image +import torch.nn.functional as F +import torchvision as tv +import torchvision.utils as vutils +import time +import cv2 +import os +from skimage import img_as_ubyte +import json +import argparse +import dlib + + +def _standard_face_pts(): + pts = ( + np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) / 256.0 + - 1.0 + ) + + return np.reshape(pts, (5, 2)) + + +def _origin_face_pts(): + pts = np.array([196.0, 226.0, 316.0, 226.0, 256.0, 286.0, 220.0, 360.4, 292.0, 360.4], np.float32) + + return np.reshape(pts, (5, 2)) + + +def get_landmark(face_landmarks, id): + part = face_landmarks.part(id) + x = part.x + y = part.y + + return (x, y) + + +def search(face_landmarks): + + x1, y1 = get_landmark(face_landmarks, 36) + x2, y2 = get_landmark(face_landmarks, 39) + x3, y3 = get_landmark(face_landmarks, 42) + x4, y4 = get_landmark(face_landmarks, 45) + + x_nose, y_nose = get_landmark(face_landmarks, 30) + + x_left_mouth, y_left_mouth = get_landmark(face_landmarks, 48) + x_right_mouth, y_right_mouth = get_landmark(face_landmarks, 54) + + x_left_eye = int((x1 + x2) / 2) + y_left_eye = int((y1 + y2) / 2) + x_right_eye = int((x3 + x4) / 2) + y_right_eye = int((y3 + y4) / 2) + + results = np.array( + [ + [x_left_eye, y_left_eye], + [x_right_eye, y_right_eye], + [x_nose, y_nose], + [x_left_mouth, y_left_mouth], + [x_right_mouth, y_right_mouth], + ] + ) + + return results + + +def compute_transformation_matrix(img, landmark, normalize, target_face_scale=1.0): + + std_pts = _standard_face_pts() # [-1,1] + target_pts = (std_pts * target_face_scale + 1) / 2 * 512.0 + + # print(target_pts) + + h, w, c = img.shape + if normalize == True: + landmark[:, 0] = landmark[:, 0] / h * 2 - 1.0 + landmark[:, 1] = landmark[:, 1] / w * 2 - 1.0 + + # print(landmark) + + affine = SimilarityTransform() + + affine.estimate(target_pts, landmark) + + return affine.params + + +def show_detection(image, box, landmark): + plt.imshow(image) + print(box[2] - box[0]) + plt.gca().add_patch( + Rectangle( + (box[1], box[0]), box[2] - box[0], box[3] - box[1], linewidth=1, edgecolor="r", facecolor="none" + ) + ) + plt.scatter(landmark[0][0], landmark[0][1]) + plt.scatter(landmark[1][0], landmark[1][1]) + plt.scatter(landmark[2][0], landmark[2][1]) + plt.scatter(landmark[3][0], landmark[3][1]) + plt.scatter(landmark[4][0], landmark[4][1]) + plt.show() + + +def affine2theta(affine, input_w, input_h, target_w, target_h): + # param = np.linalg.inv(affine) + param = affine + theta = np.zeros([2, 3]) + theta[0, 0] = param[0, 0] * input_h / target_h + theta[0, 1] = param[0, 1] * input_w / target_h + theta[0, 2] = (2 * param[0, 2] + param[0, 0] * input_h + param[0, 1] * input_w) / target_h - 1 + theta[1, 0] = param[1, 0] * input_h / target_w + theta[1, 1] = param[1, 1] * input_w / target_w + theta[1, 2] = (2 * param[1, 2] + param[1, 0] * input_h + param[1, 1] * input_w) / target_w - 1 + return theta + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("--url", type=str, default="/home/jingliao/ziyuwan/celebrities", help="input") + parser.add_argument( + "--save_url", type=str, default="/home/jingliao/ziyuwan/celebrities_detected_face_reid", help="output" + ) + opts = parser.parse_args() + + url = opts.url + save_url = opts.save_url + + ### If the origin url is None, then we don't need to reid the origin image + + os.makedirs(url, exist_ok=True) + os.makedirs(save_url, exist_ok=True) + + face_detector = dlib.get_frontal_face_detector() + landmark_locator = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") + + count = 0 + + map_id = {} + for x in os.listdir(url): + img_url = os.path.join(url, x) + pil_img = Image.open(img_url).convert("RGB") + + image = np.array(pil_img) + + start = time.time() + faces = face_detector(image) + done = time.time() + + if len(faces) == 0: + print("Warning: There is no face in %s" % (x)) + continue + + print(len(faces)) + + if len(faces) > 0: + for face_id in range(len(faces)): + current_face = faces[face_id] + face_landmarks = landmark_locator(image, current_face) + current_fl = search(face_landmarks) + + affine = compute_transformation_matrix(image, current_fl, False, target_face_scale=1.3) + aligned_face = warp(image, affine, output_shape=(512, 512, 3)) + img_name = x[:-4] + "_" + str(face_id + 1) + io.imsave(os.path.join(save_url, img_name + ".png"), img_as_ubyte(aligned_face)) + + count += 1 + + if count % 1000 == 0: + print("%d have finished ..." % (count)) + diff --git a/Face_Enhancement/data/__init__.py b/Face_Enhancement/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10a141788da65ea6527a4eecc9628603824d732b --- /dev/null +++ b/Face_Enhancement/data/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import importlib +import torch.utils.data +from data.base_dataset import BaseDataset +from data.face_dataset import FaceTestDataset + + +def create_dataloader(opt): + + instance = FaceTestDataset() + instance.initialize(opt) + print("dataset [%s] of size %d was created" % (type(instance).__name__, len(instance))) + dataloader = torch.utils.data.DataLoader( + instance, + batch_size=opt.batchSize, + shuffle=not opt.serial_batches, + num_workers=int(opt.nThreads), + drop_last=opt.isTrain, + ) + return dataloader diff --git a/Face_Enhancement/data/base_dataset.py b/Face_Enhancement/data/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..57595dd0bf9dd20e333bd78a6a97013b9a6d0a43 --- /dev/null +++ b/Face_Enhancement/data/base_dataset.py @@ -0,0 +1,125 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms +import numpy as np +import random + + +class BaseDataset(data.Dataset): + def __init__(self): + super(BaseDataset, self).__init__() + + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def initialize(self, opt): + pass + + +def get_params(opt, size): + w, h = size + new_h = h + new_w = w + if opt.preprocess_mode == "resize_and_crop": + new_h = new_w = opt.load_size + elif opt.preprocess_mode == "scale_width_and_crop": + new_w = opt.load_size + new_h = opt.load_size * h // w + elif opt.preprocess_mode == "scale_shortside_and_crop": + ss, ls = min(w, h), max(w, h) # shortside and longside + width_is_shorter = w == ss + ls = int(opt.load_size * ls / ss) + new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss) + + x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) + y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) + + flip = random.random() > 0.5 + return {"crop_pos": (x, y), "flip": flip} + + +def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True): + transform_list = [] + if "resize" in opt.preprocess_mode: + osize = [opt.load_size, opt.load_size] + transform_list.append(transforms.Resize(osize, interpolation=method)) + elif "scale_width" in opt.preprocess_mode: + transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) + elif "scale_shortside" in opt.preprocess_mode: + transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method))) + + if "crop" in opt.preprocess_mode: + transform_list.append(transforms.Lambda(lambda img: __crop(img, params["crop_pos"], opt.crop_size))) + + if opt.preprocess_mode == "none": + base = 32 + transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) + + if opt.preprocess_mode == "fixed": + w = opt.crop_size + h = round(opt.crop_size / opt.aspect_ratio) + transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method))) + + if opt.isTrain and not opt.no_flip: + transform_list.append(transforms.Lambda(lambda img: __flip(img, params["flip"]))) + + if toTensor: + transform_list += [transforms.ToTensor()] + + if normalize: + transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + return transforms.Compose(transform_list) + + +def normalize(): + return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + + +def __resize(img, w, h, method=Image.BICUBIC): + return img.resize((w, h), method) + + +def __make_power_2(img, base, method=Image.BICUBIC): + ow, oh = img.size + h = int(round(oh / base) * base) + w = int(round(ow / base) * base) + if (h == oh) and (w == ow): + return img + return img.resize((w, h), method) + + +def __scale_width(img, target_width, method=Image.BICUBIC): + ow, oh = img.size + if ow == target_width: + return img + w = target_width + h = int(target_width * oh / ow) + return img.resize((w, h), method) + + +def __scale_shortside(img, target_width, method=Image.BICUBIC): + ow, oh = img.size + ss, ls = min(ow, oh), max(ow, oh) # shortside and longside + width_is_shorter = ow == ss + if ss == target_width: + return img + ls = int(target_width * ls / ss) + nw, nh = (ss, ls) if width_is_shorter else (ls, ss) + return img.resize((nw, nh), method) + + +def __crop(img, pos, size): + ow, oh = img.size + x1, y1 = pos + tw = th = size + return img.crop((x1, y1, x1 + tw, y1 + th)) + + +def __flip(img, flip): + if flip: + return img.transpose(Image.FLIP_LEFT_RIGHT) + return img diff --git a/Face_Enhancement/data/custom_dataset.py b/Face_Enhancement/data/custom_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..aa0a0d79a5ca7a1816a2089b82e7ef90b28c0f43 --- /dev/null +++ b/Face_Enhancement/data/custom_dataset.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from data.pix2pix_dataset import Pix2pixDataset +from data.image_folder import make_dataset + + +class CustomDataset(Pix2pixDataset): + """ Dataset that loads images from directories + Use option --label_dir, --image_dir, --instance_dir to specify the directories. + The images in the directories are sorted in alphabetical order and paired in order. + """ + + @staticmethod + def modify_commandline_options(parser, is_train): + parser = Pix2pixDataset.modify_commandline_options(parser, is_train) + parser.set_defaults(preprocess_mode="resize_and_crop") + load_size = 286 if is_train else 256 + parser.set_defaults(load_size=load_size) + parser.set_defaults(crop_size=256) + parser.set_defaults(display_winsize=256) + parser.set_defaults(label_nc=13) + parser.set_defaults(contain_dontcare_label=False) + + parser.add_argument( + "--label_dir", type=str, required=True, help="path to the directory that contains label images" + ) + parser.add_argument( + "--image_dir", type=str, required=True, help="path to the directory that contains photo images" + ) + parser.add_argument( + "--instance_dir", + type=str, + default="", + help="path to the directory that contains instance maps. Leave black if not exists", + ) + return parser + + def get_paths(self, opt): + label_dir = opt.label_dir + label_paths = make_dataset(label_dir, recursive=False, read_cache=True) + + image_dir = opt.image_dir + image_paths = make_dataset(image_dir, recursive=False, read_cache=True) + + if len(opt.instance_dir) > 0: + instance_dir = opt.instance_dir + instance_paths = make_dataset(instance_dir, recursive=False, read_cache=True) + else: + instance_paths = [] + + assert len(label_paths) == len( + image_paths + ), "The #images in %s and %s do not match. Is there something wrong?" + + return label_paths, image_paths, instance_paths diff --git a/Face_Enhancement/data/face_dataset.py b/Face_Enhancement/data/face_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..75b49570a0d4b066c5a38ae86b24ddd024b00be9 --- /dev/null +++ b/Face_Enhancement/data/face_dataset.py @@ -0,0 +1,102 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from data.base_dataset import BaseDataset, get_params, get_transform +from PIL import Image +import util.util as util +import os +import torch + + +class FaceTestDataset(BaseDataset): + @staticmethod + def modify_commandline_options(parser, is_train): + parser.add_argument( + "--no_pairing_check", + action="store_true", + help="If specified, skip sanity check of correct label-image file pairing", + ) + # parser.set_defaults(contain_dontcare_label=False) + # parser.set_defaults(no_instance=True) + return parser + + def initialize(self, opt): + self.opt = opt + + image_path = os.path.join(opt.dataroot, opt.old_face_folder) + label_path = os.path.join(opt.dataroot, opt.old_face_label_folder) + + image_list = os.listdir(image_path) + image_list = sorted(image_list) + # image_list=image_list[:opt.max_dataset_size] + + self.label_paths = label_path ## Just the root dir + self.image_paths = image_list ## All the image name + + self.parts = [ + "skin", + "hair", + "l_brow", + "r_brow", + "l_eye", + "r_eye", + "eye_g", + "l_ear", + "r_ear", + "ear_r", + "nose", + "mouth", + "u_lip", + "l_lip", + "neck", + "neck_l", + "cloth", + "hat", + ] + + size = len(self.image_paths) + self.dataset_size = size + + def __getitem__(self, index): + + params = get_params(self.opt, (-1, -1)) + image_name = self.image_paths[index] + image_path = os.path.join(self.opt.dataroot, self.opt.old_face_folder, image_name) + image = Image.open(image_path) + image = image.convert("RGB") + + transform_image = get_transform(self.opt, params) + image_tensor = transform_image(image) + + img_name = image_name[:-4] + transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) + full_label = [] + + cnt = 0 + + for each_part in self.parts: + part_name = img_name + "_" + each_part + ".png" + part_url = os.path.join(self.label_paths, part_name) + + if os.path.exists(part_url): + label = Image.open(part_url).convert("RGB") + label_tensor = transform_label(label) ## 3 channels and pixel [0,1] + full_label.append(label_tensor[0]) + else: + current_part = torch.zeros((self.opt.load_size, self.opt.load_size)) + full_label.append(current_part) + cnt += 1 + + full_label_tensor = torch.stack(full_label, 0) + + input_dict = { + "label": full_label_tensor, + "image": image_tensor, + "path": image_path, + } + + return input_dict + + def __len__(self): + return self.dataset_size + diff --git a/Face_Enhancement/data/image_folder.py b/Face_Enhancement/data/image_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..7281eb23df59a7337732d5b4622977137fefdbd4 --- /dev/null +++ b/Face_Enhancement/data/image_folder.py @@ -0,0 +1,101 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch.utils.data as data +from PIL import Image +import os + +IMG_EXTENSIONS = [ + ".jpg", + ".JPG", + ".jpeg", + ".JPEG", + ".png", + ".PNG", + ".ppm", + ".PPM", + ".bmp", + ".BMP", + ".tiff", + ".webp", +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset_rec(dir, images): + assert os.path.isdir(dir), "%s is not a valid directory" % dir + + for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + + +def make_dataset(dir, recursive=False, read_cache=False, write_cache=False): + images = [] + + if read_cache: + possible_filelist = os.path.join(dir, "files.list") + if os.path.isfile(possible_filelist): + with open(possible_filelist, "r") as f: + images = f.read().splitlines() + return images + + if recursive: + make_dataset_rec(dir, images) + else: + assert os.path.isdir(dir) or os.path.islink(dir), "%s is not a valid directory" % dir + + for root, dnames, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + + if write_cache: + filelist_cache = os.path.join(dir, "files.list") + with open(filelist_cache, "w") as f: + for path in images: + f.write("%s\n" % path) + print("wrote filelist cache at %s" % filelist_cache) + + return images + + +def default_loader(path): + return Image.open(path).convert("RGB") + + +class ImageFolder(data.Dataset): + def __init__(self, root, transform=None, return_paths=False, loader=default_loader): + imgs = make_dataset(root) + if len(imgs) == 0: + raise ( + RuntimeError( + "Found 0 images in: " + root + "\n" + "Supported image extensions are: " + ",".join(IMG_EXTENSIONS) + ) + ) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) diff --git a/Face_Enhancement/data/pix2pix_dataset.py b/Face_Enhancement/data/pix2pix_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..511bd83f55be80ae50bb09c4f6c11fafd4cf8214 --- /dev/null +++ b/Face_Enhancement/data/pix2pix_dataset.py @@ -0,0 +1,108 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from data.base_dataset import BaseDataset, get_params, get_transform +from PIL import Image +import util.util as util +import os + + +class Pix2pixDataset(BaseDataset): + @staticmethod + def modify_commandline_options(parser, is_train): + parser.add_argument( + "--no_pairing_check", + action="store_true", + help="If specified, skip sanity check of correct label-image file pairing", + ) + return parser + + def initialize(self, opt): + self.opt = opt + + label_paths, image_paths, instance_paths = self.get_paths(opt) + + util.natural_sort(label_paths) + util.natural_sort(image_paths) + if not opt.no_instance: + util.natural_sort(instance_paths) + + label_paths = label_paths[: opt.max_dataset_size] + image_paths = image_paths[: opt.max_dataset_size] + instance_paths = instance_paths[: opt.max_dataset_size] + + if not opt.no_pairing_check: + for path1, path2 in zip(label_paths, image_paths): + assert self.paths_match(path1, path2), ( + "The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this." + % (path1, path2) + ) + + self.label_paths = label_paths + self.image_paths = image_paths + self.instance_paths = instance_paths + + size = len(self.label_paths) + self.dataset_size = size + + def get_paths(self, opt): + label_paths = [] + image_paths = [] + instance_paths = [] + assert False, "A subclass of Pix2pixDataset must override self.get_paths(self, opt)" + return label_paths, image_paths, instance_paths + + def paths_match(self, path1, path2): + filename1_without_ext = os.path.splitext(os.path.basename(path1))[0] + filename2_without_ext = os.path.splitext(os.path.basename(path2))[0] + return filename1_without_ext == filename2_without_ext + + def __getitem__(self, index): + # Label Image + label_path = self.label_paths[index] + label = Image.open(label_path) + params = get_params(self.opt, label.size) + transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) + label_tensor = transform_label(label) * 255.0 + label_tensor[label_tensor == 255] = self.opt.label_nc # 'unknown' is opt.label_nc + + # input image (real images) + image_path = self.image_paths[index] + assert self.paths_match( + label_path, image_path + ), "The label_path %s and image_path %s don't match." % (label_path, image_path) + image = Image.open(image_path) + image = image.convert("RGB") + + transform_image = get_transform(self.opt, params) + image_tensor = transform_image(image) + + # if using instance maps + if self.opt.no_instance: + instance_tensor = 0 + else: + instance_path = self.instance_paths[index] + instance = Image.open(instance_path) + if instance.mode == "L": + instance_tensor = transform_label(instance) * 255 + instance_tensor = instance_tensor.long() + else: + instance_tensor = transform_label(instance) + + input_dict = { + "label": label_tensor, + "instance": instance_tensor, + "image": image_tensor, + "path": image_path, + } + + # Give subclasses a chance to modify the final output + self.postprocess(input_dict) + + return input_dict + + def postprocess(self, input_dict): + return input_dict + + def __len__(self): + return self.dataset_size diff --git a/Face_Enhancement/models/__init__.py b/Face_Enhancement/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad491ed6270caf03e6d1c34e56163f5ee8fbf2bc --- /dev/null +++ b/Face_Enhancement/models/__init__.py @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import importlib +import torch + + +def find_model_using_name(model_name): + # Given the option --model [modelname], + # the file "models/modelname_model.py" + # will be imported. + model_filename = "models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + + # In the file, the class called ModelNameModel() will + # be instantiated. It has to be a subclass of torch.nn.Module, + # and it is case-insensitive. + model = None + target_model_name = model_name.replace("_", "") + "model" + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower() and issubclass(cls, torch.nn.Module): + model = cls + + if model is None: + print( + "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." + % (model_filename, target_model_name) + ) + exit(0) + + return model + + +def get_option_setter(model_name): + model_class = find_model_using_name(model_name) + return model_class.modify_commandline_options + + +def create_model(opt): + model = find_model_using_name(opt.model) + instance = model(opt) + print("model [%s] was created" % (type(instance).__name__)) + + return instance diff --git a/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/LICENSE b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4c39939e7e3aa940d405030335ec0e6ff2f2a1ee --- /dev/null +++ b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 Jiayuan MAO + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/README.md b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/README.md new file mode 100644 index 0000000000000000000000000000000000000000..779983436c9727dd0d6301a1c857f2360245b51d --- /dev/null +++ b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/README.md @@ -0,0 +1,118 @@ +# Synchronized-BatchNorm-PyTorch + +**IMPORTANT: Please read the "Implementation details and highlights" section before use.** + +Synchronized Batch Normalization implementation in PyTorch. + +This module differs from the built-in PyTorch BatchNorm as the mean and +standard-deviation are reduced across all devices during training. + +For example, when one uses `nn.DataParallel` to wrap the network during +training, PyTorch's implementation normalize the tensor on each device using +the statistics only on that device, which accelerated the computation and +is also easy to implement, but the statistics might be inaccurate. +Instead, in this synchronized version, the statistics will be computed +over all training samples distributed on multiple devices. + +Note that, for one-GPU or CPU-only case, this module behaves exactly same +as the built-in PyTorch implementation. + +This module is currently only a prototype version for research usages. As mentioned below, +it has its limitations and may even suffer from some design problems. If you have any +questions or suggestions, please feel free to +[open an issue](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues) or +[submit a pull request](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues). + +## Why Synchronized BatchNorm? + +Although the typical implementation of BatchNorm working on multiple devices (GPUs) +is fast (with no communication overhead), it inevitably reduces the size of batch size, +which potentially degenerates the performance. This is not a significant issue in some +standard vision tasks such as ImageNet classification (as the batch size per device +is usually large enough to obtain good statistics). However, it will hurt the performance +in some tasks that the batch size is usually very small (e.g., 1 per GPU). + +For example, the importance of synchronized batch normalization in object detection has been recently proved with a +an extensive analysis in the paper [MegDet: A Large Mini-Batch Object Detector](https://arxiv.org/abs/1711.07240). + +## Usage + +To use the Synchronized Batch Normalization, we add a data parallel replication callback. This introduces a slight +difference with typical usage of the `nn.DataParallel`. + +Use it with a provided, customized data parallel wrapper: + +```python +from sync_batchnorm import SynchronizedBatchNorm1d, DataParallelWithCallback + +sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) +sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) +``` + +Or, if you are using a customized data parallel module, you can use this library as a monkey patching. + +```python +from torch.nn import DataParallel # or your customized DataParallel module +from sync_batchnorm import SynchronizedBatchNorm1d, patch_replication_callback + +sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) +sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) +patch_replication_callback(sync_bn) # monkey-patching +``` + +You can use `convert_model` to convert your model to use Synchronized BatchNorm easily. + +```python +import torch.nn as nn +from torchvision import models +from sync_batchnorm import convert_model +# m is a standard pytorch model +m = models.resnet18(True) +m = nn.DataParallel(m) +# after convert, m is using SyncBN +m = convert_model(m) +``` + +See also `tests/test_sync_batchnorm.py` for numeric result comparison. + +## Implementation details and highlights + +If you are interested in how batch statistics are reduced and broadcasted among multiple devices, please take a look +at the code with detailed comments. Here we only emphasize some highlights of the implementation: + +- This implementation is in pure-python. No C++ extra extension libs. +- Easy to use as demonstrated above. +- It uses unbiased variance to update the moving average, and use `sqrt(max(var, eps))` instead of `sqrt(var + eps)`. +- The implementation requires that each module on different devices should invoke the `batchnorm` for exactly SAME +amount of times in each forward pass. For example, you can not only call `batchnorm` on GPU0 but not on GPU1. The `#i +(i = 1, 2, 3, ...)` calls of the `batchnorm` on each device will be viewed as a whole and the statistics will be reduced. +This is tricky but is a good way to handle PyTorch's dynamic computation graph. Although sounds complicated, this +will usually not be the issue for most of the models. + +## Known issues + +#### Runtime error on backward pass. + +Due to a [PyTorch Bug](https://github.com/pytorch/pytorch/issues/3883), using old PyTorch libraries will trigger an `RuntimeError` with messages like: + +``` +Assertion `pos >= 0 && pos < buffer.size()` failed. +``` + +This has already been solved in the newest PyTorch repo, which, unfortunately, has not been pushed to the official and anaconda binary release. Thus, you are required to build the PyTorch package from the source according to the + instructions [here](https://github.com/pytorch/pytorch#from-source). + +#### Numeric error. + +Because this library does not fuse the normalization and statistics operations in C++ (nor CUDA), it is less +numerically stable compared to the original PyTorch implementation. Detailed analysis can be found in +`tests/test_sync_batchnorm.py`. + +## Authors and License: + +Copyright (c) 2018-, [Jiayuan Mao](https://vccy.xyz). + +**Contributors**: [Tete Xiao](https://tetexiao.com), [DTennant](https://github.com/DTennant). + +Distributed under **MIT License** (See LICENSE) + diff --git a/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/__init__.py b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9b36c74b1808b56ded68cf080a689db7e0ee4e --- /dev/null +++ b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import set_sbn_eps_mode +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .batchnorm import patch_sync_batchnorm, convert_model +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm.py b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8d7a7325b474771a11a137053971fd40426079 --- /dev/null +++ b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections +import contextlib + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm + +try: + from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast +except ImportError: + ReduceAddCoalesced = Broadcast = None + +try: + from jactorch.parallel.comm import SyncMaster + from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback +except ImportError: + from .comm import SyncMaster + from .replicate import DataParallelWithCallback + +__all__ = [ + 'set_sbn_eps_mode', + 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', + 'patch_sync_batchnorm', 'convert_model' +] + + +SBN_EPS_MODE = 'clamp' + + +def set_sbn_eps_mode(mode): + global SBN_EPS_MODE + assert mode in ('clamp', 'plus') + SBN_EPS_MODE = mode + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dimensions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): + assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' + + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, + track_running_stats=track_running_stats) + + if not self.track_running_stats: + import warnings + warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.') + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features) + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + if hasattr(torch, 'no_grad'): + with torch.no_grad(): + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + else: + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + if SBN_EPS_MODE == 'clamp': + return mean, bias_var.clamp(self.eps) ** -0.5 + elif SBN_EPS_MODE == 'plus': + return mean, (bias_var + self.eps) ** -0.5 + else: + raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE)) + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + + +@contextlib.contextmanager +def patch_sync_batchnorm(): + import torch.nn as nn + + backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d + + nn.BatchNorm1d = SynchronizedBatchNorm1d + nn.BatchNorm2d = SynchronizedBatchNorm2d + nn.BatchNorm3d = SynchronizedBatchNorm3d + + yield + + nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup + + +def convert_model(module): + """Traverse the input module and its child recursively + and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d + to SynchronizedBatchNorm*N*d + + Args: + module: the input module needs to be convert to SyncBN model + + Examples: + >>> import torch.nn as nn + >>> import torchvision + >>> # m is a standard pytorch model + >>> m = torchvision.models.resnet18(True) + >>> m = nn.DataParallel(m) + >>> # after convert, m is using SyncBN + >>> m = convert_model(m) + """ + if isinstance(module, torch.nn.DataParallel): + mod = module.module + mod = convert_model(mod) + mod = DataParallelWithCallback(mod, device_ids=module.device_ids) + return mod + + mod = module + for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, + torch.nn.modules.batchnorm.BatchNorm2d, + torch.nn.modules.batchnorm.BatchNorm3d], + [SynchronizedBatchNorm1d, + SynchronizedBatchNorm2d, + SynchronizedBatchNorm3d]): + if isinstance(module, pth_module): + mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) + mod.running_mean = module.running_mean + mod.running_var = module.running_var + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + + for name, child in module.named_children(): + mod.add_module(name, convert_model(child)) + + return mod diff --git a/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm_reimpl.py b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm_reimpl.py new file mode 100644 index 0000000000000000000000000000000000000000..18145c3353e13d482c492ae46df91a537669fca0 --- /dev/null +++ b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm_reimpl.py @@ -0,0 +1,74 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : batchnorm_reimpl.py +# Author : acgtyrant +# Date : 11/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import torch +import torch.nn as nn +import torch.nn.init as init + +__all__ = ['BatchNorm2dReimpl'] + + +class BatchNorm2dReimpl(nn.Module): + """ + A re-implementation of batch normalization, used for testing the numerical + stability. + + Author: acgtyrant + See also: + https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 + """ + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super().__init__() + + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.weight = nn.Parameter(torch.empty(num_features)) + self.bias = nn.Parameter(torch.empty(num_features)) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_running_stats(self): + self.running_mean.zero_() + self.running_var.fill_(1) + + def reset_parameters(self): + self.reset_running_stats() + init.uniform_(self.weight) + init.zeros_(self.bias) + + def forward(self, input_): + batchsize, channels, height, width = input_.size() + numel = batchsize * height * width + input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) + sum_ = input_.sum(1) + sum_of_square = input_.pow(2).sum(1) + mean = sum_ / numel + sumvar = sum_of_square - sum_ * mean + + self.running_mean = ( + (1 - self.momentum) * self.running_mean + + self.momentum * mean.detach() + ) + unbias_var = sumvar / (numel - 1) + self.running_var = ( + (1 - self.momentum) * self.running_var + + self.momentum * unbias_var.detach() + ) + + bias_var = sumvar / numel + inv_std = 1 / (bias_var + self.eps).pow(0.5) + output = ( + (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * + self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) + + return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() + diff --git a/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/comm.py b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..922f8c4a3adaa9b32fdcaef09583be03b0d7eb2b --- /dev/null +++ b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/comm.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' + + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/replicate.py b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..b71c7b8ed51a1d6c55b1f753bdd8d90bad79bd06 --- /dev/null +++ b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/unittest.py b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..998223a0e0242dc4a5b2fcd74af79dc7232794da --- /dev/null +++ b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import unittest +import torch + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, x, y): + adiff = float((x - y).abs().max()) + if (y == 0).all(): + rdiff = 'NaN' + else: + rdiff = float((adiff / y).abs().max()) + + message = ( + 'Tensor close check failed\n' + 'adiff={}\n' + 'rdiff={}\n' + ).format(adiff, rdiff) + self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message) + diff --git a/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm.py b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..63661389782806ea2182c049448df5d05fc6d2f1 --- /dev/null +++ b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# File : test_numeric_batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. + +import unittest + +import torch +import torch.nn as nn +from torch.autograd import Variable + +from sync_batchnorm.unittest import TorchTestCase + + +def handy_var(a, unbias=True): + n = a.size(0) + asum = a.sum(dim=0) + as_sum = (a ** 2).sum(dim=0) # a square sum + sumvar = as_sum - asum * asum / n + if unbias: + return sumvar / (n - 1) + else: + return sumvar / n + + +class NumericTestCase(TorchTestCase): + def testNumericBatchNorm(self): + a = torch.rand(16, 10) + bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False) + bn.train() + + a_var1 = Variable(a, requires_grad=True) + b_var1 = bn(a_var1) + loss1 = b_var1.sum() + loss1.backward() + + a_var2 = Variable(a, requires_grad=True) + a_mean2 = a_var2.mean(dim=0, keepdim=True) + a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) + # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) + b_var2 = (a_var2 - a_mean2) / a_std2 + loss2 = b_var2.sum() + loss2.backward() + + self.assertTensorClose(bn.running_mean, a.mean(dim=0)) + self.assertTensorClose(bn.running_var, handy_var(a)) + self.assertTensorClose(a_var1.data, a_var2.data) + self.assertTensorClose(b_var1.data, b_var2.data) + self.assertTensorClose(a_var1.grad, a_var2.grad) + + +if __name__ == '__main__': + unittest.main() diff --git a/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm_v2.py b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4538ae3c50b4c457a9fa19bf22b5b1a7b666ee --- /dev/null +++ b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm_v2.py @@ -0,0 +1,62 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : test_numeric_batchnorm_v2.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 11/01/2018 +# +# Distributed under terms of the MIT license. + +""" +Test the numerical implementation of batch normalization. + +Author: acgtyrant. +See also: https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 +""" + +import unittest + +import torch +import torch.nn as nn +import torch.optim as optim + +from sync_batchnorm.unittest import TorchTestCase +from sync_batchnorm.batchnorm_reimpl import BatchNorm2dReimpl + + +class NumericTestCasev2(TorchTestCase): + def testNumericBatchNorm(self): + CHANNELS = 16 + batchnorm1 = nn.BatchNorm2d(CHANNELS, momentum=1) + optimizer1 = optim.SGD(batchnorm1.parameters(), lr=0.01) + + batchnorm2 = BatchNorm2dReimpl(CHANNELS, momentum=1) + batchnorm2.weight.data.copy_(batchnorm1.weight.data) + batchnorm2.bias.data.copy_(batchnorm1.bias.data) + optimizer2 = optim.SGD(batchnorm2.parameters(), lr=0.01) + + for _ in range(100): + input_ = torch.rand(16, CHANNELS, 16, 16) + + input1 = input_.clone().requires_grad_(True) + output1 = batchnorm1(input1) + output1.sum().backward() + optimizer1.step() + + input2 = input_.clone().requires_grad_(True) + output2 = batchnorm2(input2) + output2.sum().backward() + optimizer2.step() + + self.assertTensorClose(input1, input2) + self.assertTensorClose(output1, output2) + self.assertTensorClose(input1.grad, input2.grad) + self.assertTensorClose(batchnorm1.weight.grad, batchnorm2.weight.grad) + self.assertTensorClose(batchnorm1.bias.grad, batchnorm2.bias.grad) + self.assertTensorClose(batchnorm1.running_mean, batchnorm2.running_mean) + self.assertTensorClose(batchnorm2.running_mean, batchnorm2.running_mean) + + +if __name__ == '__main__': + unittest.main() + diff --git a/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_sync_batchnorm.py b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_sync_batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..1f7b6c64c06fc26348489cd15669501a2098c82f --- /dev/null +++ b/Face_Enhancement/models/networks/Synchronized-BatchNorm-PyTorch/tests/test_sync_batchnorm.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +# File : test_sync_batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. + +import unittest + +import torch +import torch.nn as nn +from torch.autograd import Variable + +from sync_batchnorm import set_sbn_eps_mode +from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback +from sync_batchnorm.unittest import TorchTestCase + +set_sbn_eps_mode('plus') + + +def handy_var(a, unbias=True): + n = a.size(0) + asum = a.sum(dim=0) + as_sum = (a ** 2).sum(dim=0) # a square sum + sumvar = as_sum - asum * asum / n + if unbias: + return sumvar / (n - 1) + else: + return sumvar / n + + +def _find_bn(module): + for m in module.modules(): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): + return m + + +class SyncTestCase(TorchTestCase): + def _syncParameters(self, bn1, bn2): + bn1.reset_parameters() + bn2.reset_parameters() + if bn1.affine and bn2.affine: + bn2.weight.data.copy_(bn1.weight.data) + bn2.bias.data.copy_(bn1.bias.data) + + def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): + """Check the forward and backward for the customized batch normalization.""" + bn1.train(mode=is_train) + bn2.train(mode=is_train) + + if cuda: + input = input.cuda() + + self._syncParameters(_find_bn(bn1), _find_bn(bn2)) + + input1 = Variable(input, requires_grad=True) + output1 = bn1(input1) + output1.sum().backward() + input2 = Variable(input, requires_grad=True) + output2 = bn2(input2) + output2.sum().backward() + + self.assertTensorClose(input1.data, input2.data) + self.assertTensorClose(output1.data, output2.data) + self.assertTensorClose(input1.grad, input2.grad) + self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) + self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) + + def testSyncBatchNormNormalTrain(self): + bn = nn.BatchNorm1d(10) + sync_bn = SynchronizedBatchNorm1d(10) + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) + + def testSyncBatchNormNormalEval(self): + bn = nn.BatchNorm1d(10) + sync_bn = SynchronizedBatchNorm1d(10) + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) + + def testSyncBatchNormSyncTrain(self): + bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) + + def testSyncBatchNormSyncEval(self): + bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) + + def testSyncBatchNorm2DSyncTrain(self): + bn = nn.BatchNorm2d(10) + sync_bn = SynchronizedBatchNorm2d(10) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/Face_Enhancement/models/networks/__init__.py b/Face_Enhancement/models/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..33851d1c220373f92f670cda0cde03b0fe17300f --- /dev/null +++ b/Face_Enhancement/models/networks/__init__.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +from models.networks.base_network import BaseNetwork +from models.networks.generator import * +from models.networks.encoder import * +import util.util as util + + +def find_network_using_name(target_network_name, filename): + target_class_name = target_network_name + filename + module_name = "models.networks." + filename + network = util.find_class_in_module(target_class_name, module_name) + + assert issubclass(network, BaseNetwork), "Class %s should be a subclass of BaseNetwork" % network + + return network + + +def modify_commandline_options(parser, is_train): + opt, _ = parser.parse_known_args() + + netG_cls = find_network_using_name(opt.netG, "generator") + parser = netG_cls.modify_commandline_options(parser, is_train) + if is_train: + netD_cls = find_network_using_name(opt.netD, "discriminator") + parser = netD_cls.modify_commandline_options(parser, is_train) + netE_cls = find_network_using_name("conv", "encoder") + parser = netE_cls.modify_commandline_options(parser, is_train) + + return parser + + +def create_network(cls, opt): + net = cls(opt) + net.print_network() + if len(opt.gpu_ids) > 0: + assert torch.cuda.is_available() + net.cuda() + net.init_weights(opt.init_type, opt.init_variance) + return net + + +def define_G(opt): + netG_cls = find_network_using_name(opt.netG, "generator") + return create_network(netG_cls, opt) + + +def define_D(opt): + netD_cls = find_network_using_name(opt.netD, "discriminator") + return create_network(netD_cls, opt) + + +def define_E(opt): + # there exists only one encoder type + netE_cls = find_network_using_name("conv", "encoder") + return create_network(netE_cls, opt) diff --git a/Face_Enhancement/models/networks/architecture.py b/Face_Enhancement/models/networks/architecture.py new file mode 100644 index 0000000000000000000000000000000000000000..91eb91c8c9fd6500d191456bb3dd8b39d491bb5a --- /dev/null +++ b/Face_Enhancement/models/networks/architecture.py @@ -0,0 +1,173 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +import torch.nn.utils.spectral_norm as spectral_norm +from models.networks.normalization import SPADE + + +# ResNet block that uses SPADE. +# It differs from the ResNet block of pix2pixHD in that +# it takes in the segmentation map as input, learns the skip connection if necessary, +# and applies normalization first and then convolution. +# This architecture seemed like a standard architecture for unconditional or +# class-conditional GAN architecture using residual block. +# The code was inspired from https://github.com/LMescheder/GAN_stability. +class SPADEResnetBlock(nn.Module): + def __init__(self, fin, fout, opt): + super().__init__() + # Attributes + self.learned_shortcut = fin != fout + fmiddle = min(fin, fout) + + self.opt = opt + # create conv layers + self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1) + self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) + + # apply spectral norm if specified + if "spectral" in opt.norm_G: + self.conv_0 = spectral_norm(self.conv_0) + self.conv_1 = spectral_norm(self.conv_1) + if self.learned_shortcut: + self.conv_s = spectral_norm(self.conv_s) + + # define normalization layers + spade_config_str = opt.norm_G.replace("spectral", "") + self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt) + self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt) + if self.learned_shortcut: + self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt) + + # note the resnet block with SPADE also takes in |seg|, + # the semantic segmentation map as input + def forward(self, x, seg, degraded_image): + x_s = self.shortcut(x, seg, degraded_image) + + dx = self.conv_0(self.actvn(self.norm_0(x, seg, degraded_image))) + dx = self.conv_1(self.actvn(self.norm_1(dx, seg, degraded_image))) + + out = x_s + dx + + return out + + def shortcut(self, x, seg, degraded_image): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg, degraded_image)) + else: + x_s = x + return x_s + + def actvn(self, x): + return F.leaky_relu(x, 2e-1) + + +# ResNet block used in pix2pixHD +# We keep the same architecture as pix2pixHD. +class ResnetBlock(nn.Module): + def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3): + super().__init__() + + pw = (kernel_size - 1) // 2 + self.conv_block = nn.Sequential( + nn.ReflectionPad2d(pw), + norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)), + activation, + nn.ReflectionPad2d(pw), + norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)), + ) + + def forward(self, x): + y = self.conv_block(x) + out = x + y + return out + + +# VGG architecter, used for the perceptual loss using a pretrained VGG network +class VGG19(torch.nn.Module): + def __init__(self, requires_grad=False): + super().__init__() + vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out + + +class SPADEResnetBlock_non_spade(nn.Module): + def __init__(self, fin, fout, opt): + super().__init__() + # Attributes + self.learned_shortcut = fin != fout + fmiddle = min(fin, fout) + + self.opt = opt + # create conv layers + self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1) + self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) + + # apply spectral norm if specified + if "spectral" in opt.norm_G: + self.conv_0 = spectral_norm(self.conv_0) + self.conv_1 = spectral_norm(self.conv_1) + if self.learned_shortcut: + self.conv_s = spectral_norm(self.conv_s) + + # define normalization layers + spade_config_str = opt.norm_G.replace("spectral", "") + self.norm_0 = SPADE(spade_config_str, fin, opt.semantic_nc, opt) + self.norm_1 = SPADE(spade_config_str, fmiddle, opt.semantic_nc, opt) + if self.learned_shortcut: + self.norm_s = SPADE(spade_config_str, fin, opt.semantic_nc, opt) + + # note the resnet block with SPADE also takes in |seg|, + # the semantic segmentation map as input + def forward(self, x, seg, degraded_image): + x_s = self.shortcut(x, seg, degraded_image) + + dx = self.conv_0(self.actvn(x)) + dx = self.conv_1(self.actvn(dx)) + + out = x_s + dx + + return out + + def shortcut(self, x, seg, degraded_image): + if self.learned_shortcut: + x_s = self.conv_s(x) + else: + x_s = x + return x_s + + def actvn(self, x): + return F.leaky_relu(x, 2e-1) diff --git a/Face_Enhancement/models/networks/base_network.py b/Face_Enhancement/models/networks/base_network.py new file mode 100644 index 0000000000000000000000000000000000000000..bc33f0e70082bf4be536fe5cf576f40c48800159 --- /dev/null +++ b/Face_Enhancement/models/networks/base_network.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch.nn as nn +from torch.nn import init + + +class BaseNetwork(nn.Module): + def __init__(self): + super(BaseNetwork, self).__init__() + + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def print_network(self): + if isinstance(self, list): + self = self[0] + num_params = 0 + for param in self.parameters(): + num_params += param.numel() + print( + "Network [%s] was created. Total number of parameters: %.1f million. " + "To see the architecture, do print(network)." % (type(self).__name__, num_params / 1000000) + ) + + def init_weights(self, init_type="normal", gain=0.02): + def init_func(m): + classname = m.__class__.__name__ + if classname.find("BatchNorm2d") != -1: + if hasattr(m, "weight") and m.weight is not None: + init.normal_(m.weight.data, 1.0, gain) + if hasattr(m, "bias") and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1): + if init_type == "normal": + init.normal_(m.weight.data, 0.0, gain) + elif init_type == "xavier": + init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == "xavier_uniform": + init.xavier_uniform_(m.weight.data, gain=1.0) + elif init_type == "kaiming": + init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") + elif init_type == "orthogonal": + init.orthogonal_(m.weight.data, gain=gain) + elif init_type == "none": # uses pytorch's default init method + m.reset_parameters() + else: + raise NotImplementedError("initialization method [%s] is not implemented" % init_type) + if hasattr(m, "bias") and m.bias is not None: + init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + # propagate to children + for m in self.children(): + if hasattr(m, "init_weights"): + m.init_weights(init_type, gain) diff --git a/Face_Enhancement/models/networks/encoder.py b/Face_Enhancement/models/networks/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..76acf690fd527bb9bd1dfc0c07c82573a1026d88 --- /dev/null +++ b/Face_Enhancement/models/networks/encoder.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +from models.networks.base_network import BaseNetwork +from models.networks.normalization import get_nonspade_norm_layer + + +class ConvEncoder(BaseNetwork): + """ Same architecture as the image discriminator """ + + def __init__(self, opt): + super().__init__() + + kw = 3 + pw = int(np.ceil((kw - 1.0) / 2)) + ndf = opt.ngf + norm_layer = get_nonspade_norm_layer(opt, opt.norm_E) + self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw)) + self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw)) + self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw)) + self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw)) + self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw)) + if opt.crop_size >= 256: + self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw)) + + self.so = s0 = 4 + self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256) + self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256) + + self.actvn = nn.LeakyReLU(0.2, False) + self.opt = opt + + def forward(self, x): + if x.size(2) != 256 or x.size(3) != 256: + x = F.interpolate(x, size=(256, 256), mode="bilinear") + + x = self.layer1(x) + x = self.layer2(self.actvn(x)) + x = self.layer3(self.actvn(x)) + x = self.layer4(self.actvn(x)) + x = self.layer5(self.actvn(x)) + if self.opt.crop_size >= 256: + x = self.layer6(self.actvn(x)) + x = self.actvn(x) + + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + logvar = self.fc_var(x) + + return mu, logvar diff --git a/Face_Enhancement/models/networks/generator.py b/Face_Enhancement/models/networks/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..6e24cadc882caab9ee439bb3dd288e536878565a --- /dev/null +++ b/Face_Enhancement/models/networks/generator.py @@ -0,0 +1,233 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from models.networks.base_network import BaseNetwork +from models.networks.normalization import get_nonspade_norm_layer +from models.networks.architecture import ResnetBlock as ResnetBlock +from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock +from models.networks.architecture import SPADEResnetBlock_non_spade as SPADEResnetBlock_non_spade + + +class SPADEGenerator(BaseNetwork): + @staticmethod + def modify_commandline_options(parser, is_train): + parser.set_defaults(norm_G="spectralspadesyncbatch3x3") + parser.add_argument( + "--num_upsampling_layers", + choices=("normal", "more", "most"), + default="normal", + help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator", + ) + + return parser + + def __init__(self, opt): + super().__init__() + self.opt = opt + nf = opt.ngf + + self.sw, self.sh = self.compute_latent_vector_size(opt) + + print("The size of the latent vector size is [%d,%d]" % (self.sw, self.sh)) + + if opt.use_vae: + # In case of VAE, we will sample from random z vector + self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh) + else: + # Otherwise, we make the network deterministic by starting with + # downsampled segmentation map instead of random z + if self.opt.no_parsing_map: + self.fc = nn.Conv2d(3, 16 * nf, 3, padding=1) + else: + self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1) + + if self.opt.injection_layer == "all" or self.opt.injection_layer == "1": + self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt) + else: + self.head_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt) + + if self.opt.injection_layer == "all" or self.opt.injection_layer == "2": + self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt) + self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt) + + else: + self.G_middle_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt) + self.G_middle_1 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt) + + if self.opt.injection_layer == "all" or self.opt.injection_layer == "3": + self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt) + else: + self.up_0 = SPADEResnetBlock_non_spade(16 * nf, 8 * nf, opt) + + if self.opt.injection_layer == "all" or self.opt.injection_layer == "4": + self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt) + else: + self.up_1 = SPADEResnetBlock_non_spade(8 * nf, 4 * nf, opt) + + if self.opt.injection_layer == "all" or self.opt.injection_layer == "5": + self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt) + else: + self.up_2 = SPADEResnetBlock_non_spade(4 * nf, 2 * nf, opt) + + if self.opt.injection_layer == "all" or self.opt.injection_layer == "6": + self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt) + else: + self.up_3 = SPADEResnetBlock_non_spade(2 * nf, 1 * nf, opt) + + final_nc = nf + + if opt.num_upsampling_layers == "most": + self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt) + final_nc = nf // 2 + + self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1) + + self.up = nn.Upsample(scale_factor=2) + + def compute_latent_vector_size(self, opt): + if opt.num_upsampling_layers == "normal": + num_up_layers = 5 + elif opt.num_upsampling_layers == "more": + num_up_layers = 6 + elif opt.num_upsampling_layers == "most": + num_up_layers = 7 + else: + raise ValueError("opt.num_upsampling_layers [%s] not recognized" % opt.num_upsampling_layers) + + sw = opt.load_size // (2 ** num_up_layers) + sh = round(sw / opt.aspect_ratio) + + return sw, sh + + def forward(self, input, degraded_image, z=None): + seg = input + + if self.opt.use_vae: + # we sample z from unit normal and reshape the tensor + if z is None: + z = torch.randn(input.size(0), self.opt.z_dim, dtype=torch.float32, device=input.get_device()) + x = self.fc(z) + x = x.view(-1, 16 * self.opt.ngf, self.sh, self.sw) + else: + # we downsample segmap and run convolution + if self.opt.no_parsing_map: + x = F.interpolate(degraded_image, size=(self.sh, self.sw), mode="bilinear") + else: + x = F.interpolate(seg, size=(self.sh, self.sw), mode="nearest") + x = self.fc(x) + + x = self.head_0(x, seg, degraded_image) + + x = self.up(x) + x = self.G_middle_0(x, seg, degraded_image) + + if self.opt.num_upsampling_layers == "more" or self.opt.num_upsampling_layers == "most": + x = self.up(x) + + x = self.G_middle_1(x, seg, degraded_image) + + x = self.up(x) + x = self.up_0(x, seg, degraded_image) + x = self.up(x) + x = self.up_1(x, seg, degraded_image) + x = self.up(x) + x = self.up_2(x, seg, degraded_image) + x = self.up(x) + x = self.up_3(x, seg, degraded_image) + + if self.opt.num_upsampling_layers == "most": + x = self.up(x) + x = self.up_4(x, seg, degraded_image) + + x = self.conv_img(F.leaky_relu(x, 2e-1)) + x = F.tanh(x) + + return x + + +class Pix2PixHDGenerator(BaseNetwork): + @staticmethod + def modify_commandline_options(parser, is_train): + parser.add_argument( + "--resnet_n_downsample", type=int, default=4, help="number of downsampling layers in netG" + ) + parser.add_argument( + "--resnet_n_blocks", + type=int, + default=9, + help="number of residual blocks in the global generator network", + ) + parser.add_argument( + "--resnet_kernel_size", type=int, default=3, help="kernel size of the resnet block" + ) + parser.add_argument( + "--resnet_initial_kernel_size", type=int, default=7, help="kernel size of the first convolution" + ) + # parser.set_defaults(norm_G='instance') + return parser + + def __init__(self, opt): + super().__init__() + input_nc = 3 + + # print("xxxxx") + # print(opt.norm_G) + norm_layer = get_nonspade_norm_layer(opt, opt.norm_G) + activation = nn.ReLU(False) + + model = [] + + # initial conv + model += [ + nn.ReflectionPad2d(opt.resnet_initial_kernel_size // 2), + norm_layer(nn.Conv2d(input_nc, opt.ngf, kernel_size=opt.resnet_initial_kernel_size, padding=0)), + activation, + ] + + # downsample + mult = 1 + for i in range(opt.resnet_n_downsample): + model += [ + norm_layer(nn.Conv2d(opt.ngf * mult, opt.ngf * mult * 2, kernel_size=3, stride=2, padding=1)), + activation, + ] + mult *= 2 + + # resnet blocks + for i in range(opt.resnet_n_blocks): + model += [ + ResnetBlock( + opt.ngf * mult, + norm_layer=norm_layer, + activation=activation, + kernel_size=opt.resnet_kernel_size, + ) + ] + + # upsample + for i in range(opt.resnet_n_downsample): + nc_in = int(opt.ngf * mult) + nc_out = int((opt.ngf * mult) / 2) + model += [ + norm_layer( + nn.ConvTranspose2d(nc_in, nc_out, kernel_size=3, stride=2, padding=1, output_padding=1) + ), + activation, + ] + mult = mult // 2 + + # final output conv + model += [ + nn.ReflectionPad2d(3), + nn.Conv2d(nc_out, opt.output_nc, kernel_size=7, padding=0), + nn.Tanh(), + ] + + self.model = nn.Sequential(*model) + + def forward(self, input, degraded_image, z=None): + return self.model(degraded_image) + diff --git a/Face_Enhancement/models/networks/normalization.py b/Face_Enhancement/models/networks/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..a865db0290c7159c6e641bbc52e14fbc79dde289 --- /dev/null +++ b/Face_Enhancement/models/networks/normalization.py @@ -0,0 +1,100 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import re +import torch +import torch.nn as nn +import torch.nn.functional as F +from models.networks.sync_batchnorm import SynchronizedBatchNorm2d +import torch.nn.utils.spectral_norm as spectral_norm + + +def get_nonspade_norm_layer(opt, norm_type="instance"): + # helper function to get # output channels of the previous layer + def get_out_channel(layer): + if hasattr(layer, "out_channels"): + return getattr(layer, "out_channels") + return layer.weight.size(0) + + # this function will be returned + def add_norm_layer(layer): + nonlocal norm_type + if norm_type.startswith("spectral"): + layer = spectral_norm(layer) + subnorm_type = norm_type[len("spectral") :] + + if subnorm_type == "none" or len(subnorm_type) == 0: + return layer + + # remove bias in the previous layer, which is meaningless + # since it has no effect after normalization + if getattr(layer, "bias", None) is not None: + delattr(layer, "bias") + layer.register_parameter("bias", None) + + if subnorm_type == "batch": + norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) + elif subnorm_type == "sync_batch": + norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True) + elif subnorm_type == "instance": + norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) + else: + raise ValueError("normalization layer %s is not recognized" % subnorm_type) + + return nn.Sequential(layer, norm_layer) + + return add_norm_layer + + +class SPADE(nn.Module): + def __init__(self, config_text, norm_nc, label_nc, opt): + super().__init__() + + assert config_text.startswith("spade") + parsed = re.search("spade(\D+)(\d)x\d", config_text) + param_free_norm_type = str(parsed.group(1)) + ks = int(parsed.group(2)) + self.opt = opt + if param_free_norm_type == "instance": + self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) + elif param_free_norm_type == "syncbatch": + self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False) + elif param_free_norm_type == "batch": + self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False) + else: + raise ValueError("%s is not a recognized param-free norm type in SPADE" % param_free_norm_type) + + # The dimension of the intermediate embedding space. Yes, hardcoded. + nhidden = 128 + + pw = ks // 2 + + if self.opt.no_parsing_map: + self.mlp_shared = nn.Sequential(nn.Conv2d(3, nhidden, kernel_size=ks, padding=pw), nn.ReLU()) + else: + self.mlp_shared = nn.Sequential( + nn.Conv2d(label_nc + 3, nhidden, kernel_size=ks, padding=pw), nn.ReLU() + ) + self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) + self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) + + def forward(self, x, segmap, degraded_image): + + # Part 1. generate parameter-free normalized activations + normalized = self.param_free_norm(x) + + # Part 2. produce scaling and bias conditioned on semantic map + segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest") + degraded_face = F.interpolate(degraded_image, size=x.size()[2:], mode="bilinear") + + if self.opt.no_parsing_map: + actv = self.mlp_shared(degraded_face) + else: + actv = self.mlp_shared(torch.cat((segmap, degraded_face), dim=1)) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + + # apply scale and bias + out = normalized * (1 + gamma) + beta + + return out diff --git a/Face_Enhancement/models/networks/sync_batchnorm/__init__.py b/Face_Enhancement/models/networks/sync_batchnorm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9b36c74b1808b56ded68cf080a689db7e0ee4e --- /dev/null +++ b/Face_Enhancement/models/networks/sync_batchnorm/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import set_sbn_eps_mode +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .batchnorm import patch_sync_batchnorm, convert_model +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/Face_Enhancement/models/networks/sync_batchnorm/batchnorm.py b/Face_Enhancement/models/networks/sync_batchnorm/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8d7a7325b474771a11a137053971fd40426079 --- /dev/null +++ b/Face_Enhancement/models/networks/sync_batchnorm/batchnorm.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections +import contextlib + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm + +try: + from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast +except ImportError: + ReduceAddCoalesced = Broadcast = None + +try: + from jactorch.parallel.comm import SyncMaster + from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback +except ImportError: + from .comm import SyncMaster + from .replicate import DataParallelWithCallback + +__all__ = [ + 'set_sbn_eps_mode', + 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', + 'patch_sync_batchnorm', 'convert_model' +] + + +SBN_EPS_MODE = 'clamp' + + +def set_sbn_eps_mode(mode): + global SBN_EPS_MODE + assert mode in ('clamp', 'plus') + SBN_EPS_MODE = mode + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dimensions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): + assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' + + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, + track_running_stats=track_running_stats) + + if not self.track_running_stats: + import warnings + warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.') + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features) + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + if hasattr(torch, 'no_grad'): + with torch.no_grad(): + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + else: + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + if SBN_EPS_MODE == 'clamp': + return mean, bias_var.clamp(self.eps) ** -0.5 + elif SBN_EPS_MODE == 'plus': + return mean, (bias_var + self.eps) ** -0.5 + else: + raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE)) + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + + +@contextlib.contextmanager +def patch_sync_batchnorm(): + import torch.nn as nn + + backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d + + nn.BatchNorm1d = SynchronizedBatchNorm1d + nn.BatchNorm2d = SynchronizedBatchNorm2d + nn.BatchNorm3d = SynchronizedBatchNorm3d + + yield + + nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup + + +def convert_model(module): + """Traverse the input module and its child recursively + and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d + to SynchronizedBatchNorm*N*d + + Args: + module: the input module needs to be convert to SyncBN model + + Examples: + >>> import torch.nn as nn + >>> import torchvision + >>> # m is a standard pytorch model + >>> m = torchvision.models.resnet18(True) + >>> m = nn.DataParallel(m) + >>> # after convert, m is using SyncBN + >>> m = convert_model(m) + """ + if isinstance(module, torch.nn.DataParallel): + mod = module.module + mod = convert_model(mod) + mod = DataParallelWithCallback(mod, device_ids=module.device_ids) + return mod + + mod = module + for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, + torch.nn.modules.batchnorm.BatchNorm2d, + torch.nn.modules.batchnorm.BatchNorm3d], + [SynchronizedBatchNorm1d, + SynchronizedBatchNorm2d, + SynchronizedBatchNorm3d]): + if isinstance(module, pth_module): + mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) + mod.running_mean = module.running_mean + mod.running_var = module.running_var + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + + for name, child in module.named_children(): + mod.add_module(name, convert_model(child)) + + return mod diff --git a/Face_Enhancement/models/networks/sync_batchnorm/batchnorm_reimpl.py b/Face_Enhancement/models/networks/sync_batchnorm/batchnorm_reimpl.py new file mode 100644 index 0000000000000000000000000000000000000000..18145c3353e13d482c492ae46df91a537669fca0 --- /dev/null +++ b/Face_Enhancement/models/networks/sync_batchnorm/batchnorm_reimpl.py @@ -0,0 +1,74 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : batchnorm_reimpl.py +# Author : acgtyrant +# Date : 11/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import torch +import torch.nn as nn +import torch.nn.init as init + +__all__ = ['BatchNorm2dReimpl'] + + +class BatchNorm2dReimpl(nn.Module): + """ + A re-implementation of batch normalization, used for testing the numerical + stability. + + Author: acgtyrant + See also: + https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 + """ + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super().__init__() + + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.weight = nn.Parameter(torch.empty(num_features)) + self.bias = nn.Parameter(torch.empty(num_features)) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_running_stats(self): + self.running_mean.zero_() + self.running_var.fill_(1) + + def reset_parameters(self): + self.reset_running_stats() + init.uniform_(self.weight) + init.zeros_(self.bias) + + def forward(self, input_): + batchsize, channels, height, width = input_.size() + numel = batchsize * height * width + input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) + sum_ = input_.sum(1) + sum_of_square = input_.pow(2).sum(1) + mean = sum_ / numel + sumvar = sum_of_square - sum_ * mean + + self.running_mean = ( + (1 - self.momentum) * self.running_mean + + self.momentum * mean.detach() + ) + unbias_var = sumvar / (numel - 1) + self.running_var = ( + (1 - self.momentum) * self.running_var + + self.momentum * unbias_var.detach() + ) + + bias_var = sumvar / numel + inv_std = 1 / (bias_var + self.eps).pow(0.5) + output = ( + (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * + self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) + + return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() + diff --git a/Face_Enhancement/models/networks/sync_batchnorm/comm.py b/Face_Enhancement/models/networks/sync_batchnorm/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..922f8c4a3adaa9b32fdcaef09583be03b0d7eb2b --- /dev/null +++ b/Face_Enhancement/models/networks/sync_batchnorm/comm.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' + + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/Face_Enhancement/models/networks/sync_batchnorm/replicate.py b/Face_Enhancement/models/networks/sync_batchnorm/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..b71c7b8ed51a1d6c55b1f753bdd8d90bad79bd06 --- /dev/null +++ b/Face_Enhancement/models/networks/sync_batchnorm/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/Face_Enhancement/models/networks/sync_batchnorm/unittest.py b/Face_Enhancement/models/networks/sync_batchnorm/unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..998223a0e0242dc4a5b2fcd74af79dc7232794da --- /dev/null +++ b/Face_Enhancement/models/networks/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import unittest +import torch + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, x, y): + adiff = float((x - y).abs().max()) + if (y == 0).all(): + rdiff = 'NaN' + else: + rdiff = float((adiff / y).abs().max()) + + message = ( + 'Tensor close check failed\n' + 'adiff={}\n' + 'rdiff={}\n' + ).format(adiff, rdiff) + self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message) + diff --git a/Face_Enhancement/models/pix2pix_model.py b/Face_Enhancement/models/pix2pix_model.py new file mode 100644 index 0000000000000000000000000000000000000000..41d6df671752f11ab7001d5b1b3e82034c2e6493 --- /dev/null +++ b/Face_Enhancement/models/pix2pix_model.py @@ -0,0 +1,246 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import models.networks as networks +import util.util as util + + +class Pix2PixModel(torch.nn.Module): + @staticmethod + def modify_commandline_options(parser, is_train): + networks.modify_commandline_options(parser, is_train) + return parser + + def __init__(self, opt): + super().__init__() + self.opt = opt + self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() else torch.FloatTensor + self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() else torch.ByteTensor + + self.netG, self.netD, self.netE = self.initialize_networks(opt) + + # set loss functions + if opt.isTrain: + self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.FloatTensor, opt=self.opt) + self.criterionFeat = torch.nn.L1Loss() + if not opt.no_vgg_loss: + self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids) + if opt.use_vae: + self.KLDLoss = networks.KLDLoss() + + # Entry point for all calls involving forward pass + # of deep networks. We used this approach since DataParallel module + # can't parallelize custom functions, we branch to different + # routines based on |mode|. + def forward(self, data, mode): + input_semantics, real_image, degraded_image = self.preprocess_input(data) + + if mode == "generator": + g_loss, generated = self.compute_generator_loss(input_semantics, degraded_image, real_image) + return g_loss, generated + elif mode == "discriminator": + d_loss = self.compute_discriminator_loss(input_semantics, degraded_image, real_image) + return d_loss + elif mode == "encode_only": + z, mu, logvar = self.encode_z(real_image) + return mu, logvar + elif mode == "inference": + with torch.no_grad(): + fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image) + return fake_image + else: + raise ValueError("|mode| is invalid") + + def create_optimizers(self, opt): + G_params = list(self.netG.parameters()) + if opt.use_vae: + G_params += list(self.netE.parameters()) + if opt.isTrain: + D_params = list(self.netD.parameters()) + + beta1, beta2 = opt.beta1, opt.beta2 + if opt.no_TTUR: + G_lr, D_lr = opt.lr, opt.lr + else: + G_lr, D_lr = opt.lr / 2, opt.lr * 2 + + optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2)) + optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2)) + + return optimizer_G, optimizer_D + + def save(self, epoch): + util.save_network(self.netG, "G", epoch, self.opt) + util.save_network(self.netD, "D", epoch, self.opt) + if self.opt.use_vae: + util.save_network(self.netE, "E", epoch, self.opt) + + ############################################################################ + # Private helper methods + ############################################################################ + + def initialize_networks(self, opt): + netG = networks.define_G(opt) + netD = networks.define_D(opt) if opt.isTrain else None + netE = networks.define_E(opt) if opt.use_vae else None + + if not opt.isTrain or opt.continue_train: + netG = util.load_network(netG, "G", opt.which_epoch, opt) + if opt.isTrain: + netD = util.load_network(netD, "D", opt.which_epoch, opt) + if opt.use_vae: + netE = util.load_network(netE, "E", opt.which_epoch, opt) + + return netG, netD, netE + + # preprocess the input, such as moving the tensors to GPUs and + # transforming the label map to one-hot encoding + # |data|: dictionary of the input data + + def preprocess_input(self, data): + # move to GPU and change data types + # data['label'] = data['label'].long() + + if not self.opt.isTrain: + if self.use_gpu(): + data["label"] = data["label"].cuda() + data["image"] = data["image"].cuda() + return data["label"], data["image"], data["image"] + + ## While testing, the input image is the degraded face + if self.use_gpu(): + data["label"] = data["label"].cuda() + data["degraded_image"] = data["degraded_image"].cuda() + data["image"] = data["image"].cuda() + + # # create one-hot label map + # label_map = data['label'] + # bs, _, h, w = label_map.size() + # nc = self.opt.label_nc + 1 if self.opt.contain_dontcare_label \ + # else self.opt.label_nc + # input_label = self.FloatTensor(bs, nc, h, w).zero_() + # input_semantics = input_label.scatter_(1, label_map, 1.0) + + return data["label"], data["image"], data["degraded_image"] + + def compute_generator_loss(self, input_semantics, degraded_image, real_image): + G_losses = {} + + fake_image, KLD_loss = self.generate_fake( + input_semantics, degraded_image, real_image, compute_kld_loss=self.opt.use_vae + ) + + if self.opt.use_vae: + G_losses["KLD"] = KLD_loss + + pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image) + + G_losses["GAN"] = self.criterionGAN(pred_fake, True, for_discriminator=False) + + if not self.opt.no_ganFeat_loss: + num_D = len(pred_fake) + GAN_Feat_loss = self.FloatTensor(1).fill_(0) + for i in range(num_D): # for each discriminator + # last output is the final prediction, so we exclude it + num_intermediate_outputs = len(pred_fake[i]) - 1 + for j in range(num_intermediate_outputs): # for each layer output + unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) + GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D + G_losses["GAN_Feat"] = GAN_Feat_loss + + if not self.opt.no_vgg_loss: + G_losses["VGG"] = self.criterionVGG(fake_image, real_image) * self.opt.lambda_vgg + + return G_losses, fake_image + + def compute_discriminator_loss(self, input_semantics, degraded_image, real_image): + D_losses = {} + with torch.no_grad(): + fake_image, _ = self.generate_fake(input_semantics, degraded_image, real_image) + fake_image = fake_image.detach() + fake_image.requires_grad_() + + pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image) + + D_losses["D_Fake"] = self.criterionGAN(pred_fake, False, for_discriminator=True) + D_losses["D_real"] = self.criterionGAN(pred_real, True, for_discriminator=True) + + return D_losses + + def encode_z(self, real_image): + mu, logvar = self.netE(real_image) + z = self.reparameterize(mu, logvar) + return z, mu, logvar + + def generate_fake(self, input_semantics, degraded_image, real_image, compute_kld_loss=False): + z = None + KLD_loss = None + if self.opt.use_vae: + z, mu, logvar = self.encode_z(real_image) + if compute_kld_loss: + KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld + + fake_image = self.netG(input_semantics, degraded_image, z=z) + + assert ( + not compute_kld_loss + ) or self.opt.use_vae, "You cannot compute KLD loss if opt.use_vae == False" + + return fake_image, KLD_loss + + # Given fake and real image, return the prediction of discriminator + # for each fake and real image. + + def discriminate(self, input_semantics, fake_image, real_image): + + if self.opt.no_parsing_map: + fake_concat = fake_image + real_concat = real_image + else: + fake_concat = torch.cat([input_semantics, fake_image], dim=1) + real_concat = torch.cat([input_semantics, real_image], dim=1) + + # In Batch Normalization, the fake and real images are + # recommended to be in the same batch to avoid disparate + # statistics in fake and real images. + # So both fake and real images are fed to D all at once. + fake_and_real = torch.cat([fake_concat, real_concat], dim=0) + + discriminator_out = self.netD(fake_and_real) + + pred_fake, pred_real = self.divide_pred(discriminator_out) + + return pred_fake, pred_real + + # Take the prediction of fake and real images from the combined batch + def divide_pred(self, pred): + # the prediction contains the intermediate outputs of multiscale GAN, + # so it's usually a list + if type(pred) == list: + fake = [] + real = [] + for p in pred: + fake.append([tensor[: tensor.size(0) // 2] for tensor in p]) + real.append([tensor[tensor.size(0) // 2 :] for tensor in p]) + else: + fake = pred[: pred.size(0) // 2] + real = pred[pred.size(0) // 2 :] + + return fake, real + + def get_edges(self, t): + edge = self.ByteTensor(t.size()).zero_() + edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1]) + edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1]) + edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) + edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) + return edge.float() + + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return eps.mul(std) + mu + + def use_gpu(self): + return len(self.opt.gpu_ids) > 0 diff --git a/Face_Enhancement/options/__init__.py b/Face_Enhancement/options/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..59e481eb93dda48c81e04dd491cd3c9190c8eeb4 --- /dev/null +++ b/Face_Enhancement/options/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/Face_Enhancement/options/base_options.py b/Face_Enhancement/options/base_options.py new file mode 100644 index 0000000000000000000000000000000000000000..af67450092b2428c71f2a40941efe612378ea0bd --- /dev/null +++ b/Face_Enhancement/options/base_options.py @@ -0,0 +1,294 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +import argparse +import os +from util import util +import torch +import models +import data +import pickle + + +class BaseOptions: + def __init__(self): + self.initialized = False + + def initialize(self, parser): + # experiment specifics + parser.add_argument( + "--name", + type=str, + default="label2coco", + help="name of the experiment. It decides where to store samples and models", + ) + + parser.add_argument( + "--gpu_ids", type=str, default="0", help="gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU" + ) + parser.add_argument( + "--checkpoints_dir", type=str, default="./checkpoints", help="models are saved here" + ) + parser.add_argument("--model", type=str, default="pix2pix", help="which model to use") + parser.add_argument( + "--norm_G", + type=str, + default="spectralinstance", + help="instance normalization or batch normalization", + ) + parser.add_argument( + "--norm_D", + type=str, + default="spectralinstance", + help="instance normalization or batch normalization", + ) + parser.add_argument( + "--norm_E", + type=str, + default="spectralinstance", + help="instance normalization or batch normalization", + ) + parser.add_argument("--phase", type=str, default="train", help="train, val, test, etc") + + # input/output sizes + parser.add_argument("--batchSize", type=int, default=1, help="input batch size") + parser.add_argument( + "--preprocess_mode", + type=str, + default="scale_width_and_crop", + help="scaling and cropping of images at load time.", + choices=( + "resize_and_crop", + "crop", + "scale_width", + "scale_width_and_crop", + "scale_shortside", + "scale_shortside_and_crop", + "fixed", + "none", + "resize", + ), + ) + parser.add_argument( + "--load_size", + type=int, + default=1024, + help="Scale images to this size. The final image will be cropped to --crop_size.", + ) + parser.add_argument( + "--crop_size", + type=int, + default=512, + help="Crop to the width of crop_size (after initially scaling the images to load_size.)", + ) + parser.add_argument( + "--aspect_ratio", + type=float, + default=1.0, + help="The ratio width/height. The final height of the load image will be crop_size/aspect_ratio", + ) + parser.add_argument( + "--label_nc", + type=int, + default=182, + help="# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dopntcare_label.", + ) + parser.add_argument( + "--contain_dontcare_label", + action="store_true", + help="if the label map contains dontcare label (dontcare=255)", + ) + parser.add_argument("--output_nc", type=int, default=3, help="# of output image channels") + + # for setting inputs + parser.add_argument("--dataroot", type=str, default="./datasets/cityscapes/") + parser.add_argument("--dataset_mode", type=str, default="coco") + parser.add_argument( + "--serial_batches", + action="store_true", + help="if true, takes images in order to make batches, otherwise takes them randomly", + ) + parser.add_argument( + "--no_flip", + action="store_true", + help="if specified, do not flip the images for data argumentation", + ) + parser.add_argument("--nThreads", default=0, type=int, help="# threads for loading data") + parser.add_argument( + "--max_dataset_size", + type=int, + default=sys.maxsize, + help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.", + ) + parser.add_argument( + "--load_from_opt_file", + action="store_true", + help="load the options from checkpoints and use that as default", + ) + parser.add_argument( + "--cache_filelist_write", + action="store_true", + help="saves the current filelist into a text file, so that it loads faster", + ) + parser.add_argument( + "--cache_filelist_read", action="store_true", help="reads from the file list cache" + ) + + # for displays + parser.add_argument("--display_winsize", type=int, default=400, help="display window size") + + # for generator + parser.add_argument( + "--netG", type=str, default="spade", help="selects model to use for netG (pix2pixhd | spade)" + ) + parser.add_argument("--ngf", type=int, default=64, help="# of gen filters in first conv layer") + parser.add_argument( + "--init_type", + type=str, + default="xavier", + help="network initialization [normal|xavier|kaiming|orthogonal]", + ) + parser.add_argument( + "--init_variance", type=float, default=0.02, help="variance of the initialization distribution" + ) + parser.add_argument("--z_dim", type=int, default=256, help="dimension of the latent z vector") + parser.add_argument( + "--no_parsing_map", action="store_true", help="During training, we do not use the parsing map" + ) + + # for instance-wise features + parser.add_argument( + "--no_instance", action="store_true", help="if specified, do *not* add instance map as input" + ) + parser.add_argument( + "--nef", type=int, default=16, help="# of encoder filters in the first conv layer" + ) + parser.add_argument("--use_vae", action="store_true", help="enable training with an image encoder.") + parser.add_argument( + "--tensorboard_log", action="store_true", help="use tensorboard to record the resutls" + ) + + # parser.add_argument('--img_dir',) + parser.add_argument( + "--old_face_folder", type=str, default="", help="The folder name of input old face" + ) + parser.add_argument( + "--old_face_label_folder", type=str, default="", help="The folder name of input old face label" + ) + + parser.add_argument("--injection_layer", type=str, default="all", help="") + + self.initialized = True + return parser + + def gather_options(self): + # initialize parser with basic options + if not self.initialized: + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + opt, unknown = parser.parse_known_args() + + # modify model-related parser options + model_name = opt.model + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + + # modify dataset-related parser options + # dataset_mode = opt.dataset_mode + # dataset_option_setter = data.get_option_setter(dataset_mode) + # parser = dataset_option_setter(parser, self.isTrain) + + opt, unknown = parser.parse_known_args() + + # if there is opt_file, load it. + # The previous default options will be overwritten + if opt.load_from_opt_file: + parser = self.update_options_from_file(parser, opt) + + opt = parser.parse_args() + self.parser = parser + return opt + + def print_options(self, opt): + message = "" + message += "----------------- Options ---------------\n" + for k, v in sorted(vars(opt).items()): + comment = "" + default = self.parser.get_default(k) + if v != default: + comment = "\t[default: %s]" % str(default) + message += "{:>25}: {:<30}{}\n".format(str(k), str(v), comment) + message += "----------------- End -------------------" + # print(message) + + def option_file_path(self, opt, makedir=False): + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + if makedir: + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, "opt") + return file_name + + def save_options(self, opt): + file_name = self.option_file_path(opt, makedir=True) + with open(file_name + ".txt", "wt") as opt_file: + for k, v in sorted(vars(opt).items()): + comment = "" + default = self.parser.get_default(k) + if v != default: + comment = "\t[default: %s]" % str(default) + opt_file.write("{:>25}: {:<30}{}\n".format(str(k), str(v), comment)) + + with open(file_name + ".pkl", "wb") as opt_file: + pickle.dump(opt, opt_file) + + def update_options_from_file(self, parser, opt): + new_opt = self.load_options(opt) + for k, v in sorted(vars(opt).items()): + if hasattr(new_opt, k) and v != getattr(new_opt, k): + new_val = getattr(new_opt, k) + parser.set_defaults(**{k: new_val}) + return parser + + def load_options(self, opt): + file_name = self.option_file_path(opt, makedir=False) + new_opt = pickle.load(open(file_name + ".pkl", "rb")) + return new_opt + + def parse(self, save=False): + + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + opt.contain_dontcare_label = False + + self.print_options(opt) + if opt.isTrain: + self.save_options(opt) + + # Set semantic_nc based on the option. + # This will be convenient in many places + opt.semantic_nc = ( + opt.label_nc + (1 if opt.contain_dontcare_label else 0) + (0 if opt.no_instance else 1) + ) + + # set gpu ids + str_ids = opt.gpu_ids.split(",") + opt.gpu_ids = [] + for str_id in str_ids: + int_id = int(str_id) + if int_id >= 0: + opt.gpu_ids.append(int_id) + + if len(opt.gpu_ids) > 0: + print("The main GPU is ") + print(opt.gpu_ids[0]) + torch.cuda.set_device(opt.gpu_ids[0]) + + assert ( + len(opt.gpu_ids) == 0 or opt.batchSize % len(opt.gpu_ids) == 0 + ), "Batch size %d is wrong. It must be a multiple of # GPUs %d." % (opt.batchSize, len(opt.gpu_ids)) + + self.opt = opt + return self.opt diff --git a/Face_Enhancement/options/test_options.py b/Face_Enhancement/options/test_options.py new file mode 100644 index 0000000000000000000000000000000000000000..4fc38bd4ed4e2e15b05f7263ac6f906adb7e5ff9 --- /dev/null +++ b/Face_Enhancement/options/test_options.py @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + def initialize(self, parser): + BaseOptions.initialize(self, parser) + parser.add_argument("--results_dir", type=str, default="./results/", help="saves results here.") + parser.add_argument( + "--which_epoch", + type=str, + default="latest", + help="which epoch to load? set to latest to use latest cached model", + ) + parser.add_argument("--how_many", type=int, default=float("inf"), help="how many test images to run") + + parser.set_defaults( + preprocess_mode="scale_width_and_crop", crop_size=256, load_size=256, display_winsize=256 + ) + parser.set_defaults(serial_batches=True) + parser.set_defaults(no_flip=True) + parser.set_defaults(phase="test") + self.isTrain = False + return parser diff --git a/Face_Enhancement/requirements.txt b/Face_Enhancement/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..29cee16d65ba3532854bb42f799c38e11169ad75 --- /dev/null +++ b/Face_Enhancement/requirements.txt @@ -0,0 +1,9 @@ +torch>=1.0.0 +torchvision +dominate>=2.3.1 +wandb +dill +scikit-image +tensorboardX +scipy +opencv-python \ No newline at end of file diff --git a/Face_Enhancement/test_face.py b/Face_Enhancement/test_face.py new file mode 100644 index 0000000000000000000000000000000000000000..4e79e1fbf590ae863eb34d6ee432d4ef2e5a54cf --- /dev/null +++ b/Face_Enhancement/test_face.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +from collections import OrderedDict + +import data +from options.test_options import TestOptions +from models.pix2pix_model import Pix2PixModel +from util.visualizer import Visualizer +import torchvision.utils as vutils +import warnings +warnings.filterwarnings("ignore", category=UserWarning) + +opt = TestOptions().parse() + +dataloader = data.create_dataloader(opt) + +model = Pix2PixModel(opt) +model.eval() + +visualizer = Visualizer(opt) + + +single_save_url = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir, "each_img") + + +if not os.path.exists(single_save_url): + os.makedirs(single_save_url) + + +for i, data_i in enumerate(dataloader): + if i * opt.batchSize >= opt.how_many: + break + + generated = model(data_i, mode="inference") + + img_path = data_i["path"] + + for b in range(generated.shape[0]): + img_name = os.path.split(img_path[b])[-1] + save_img_url = os.path.join(single_save_url, img_name) + + vutils.save_image((generated[b] + 1) / 2, save_img_url) + diff --git a/Face_Enhancement/util/__init__.py b/Face_Enhancement/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..59e481eb93dda48c81e04dd491cd3c9190c8eeb4 --- /dev/null +++ b/Face_Enhancement/util/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/Face_Enhancement/util/iter_counter.py b/Face_Enhancement/util/iter_counter.py new file mode 100644 index 0000000000000000000000000000000000000000..277bb67bb41d613e21f8400e3733490909ff73b7 --- /dev/null +++ b/Face_Enhancement/util/iter_counter.py @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +import time +import numpy as np + + +# Helper class that keeps track of training iterations +class IterationCounter: + def __init__(self, opt, dataset_size): + self.opt = opt + self.dataset_size = dataset_size + + self.first_epoch = 1 + self.total_epochs = opt.niter + opt.niter_decay + self.epoch_iter = 0 # iter number within each epoch + self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, "iter.txt") + if opt.isTrain and opt.continue_train: + try: + self.first_epoch, self.epoch_iter = np.loadtxt( + self.iter_record_path, delimiter=",", dtype=int + ) + print("Resuming from epoch %d at iteration %d" % (self.first_epoch, self.epoch_iter)) + except: + print( + "Could not load iteration record at %s. Starting from beginning." % self.iter_record_path + ) + + self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter + + # return the iterator of epochs for the training + def training_epochs(self): + return range(self.first_epoch, self.total_epochs + 1) + + def record_epoch_start(self, epoch): + self.epoch_start_time = time.time() + self.epoch_iter = 0 + self.last_iter_time = time.time() + self.current_epoch = epoch + + def record_one_iteration(self): + current_time = time.time() + + # the last remaining batch is dropped (see data/__init__.py), + # so we can assume batch size is always opt.batchSize + self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize + self.last_iter_time = current_time + self.total_steps_so_far += self.opt.batchSize + self.epoch_iter += self.opt.batchSize + + def record_epoch_end(self): + current_time = time.time() + self.time_per_epoch = current_time - self.epoch_start_time + print( + "End of epoch %d / %d \t Time Taken: %d sec" + % (self.current_epoch, self.total_epochs, self.time_per_epoch) + ) + if self.current_epoch % self.opt.save_epoch_freq == 0: + np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0), delimiter=",", fmt="%d") + print("Saved current iteration count at %s." % self.iter_record_path) + + def record_current_iter(self): + np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter), delimiter=",", fmt="%d") + print("Saved current iteration count at %s." % self.iter_record_path) + + def needs_saving(self): + return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize + + def needs_printing(self): + return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize + + def needs_displaying(self): + return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize diff --git a/Face_Enhancement/util/util.py b/Face_Enhancement/util/util.py new file mode 100644 index 0000000000000000000000000000000000000000..e18b4a26082449977b27a4c1506649a2447988b1 --- /dev/null +++ b/Face_Enhancement/util/util.py @@ -0,0 +1,210 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import re +import importlib +import torch +from argparse import Namespace +import numpy as np +from PIL import Image +import os +import argparse +import dill as pickle + + +def save_obj(obj, name): + with open(name, "wb") as f: + pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) + + +def load_obj(name): + with open(name, "rb") as f: + return pickle.load(f) + + +def copyconf(default_opt, **kwargs): + conf = argparse.Namespace(**vars(default_opt)) + for key in kwargs: + print(key, kwargs[key]) + setattr(conf, key, kwargs[key]) + return conf + + +# Converts a Tensor into a Numpy array +# |imtype|: the desired type of the converted numpy array +def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False): + if isinstance(image_tensor, list): + image_numpy = [] + for i in range(len(image_tensor)): + image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) + return image_numpy + + if image_tensor.dim() == 4: + # transform each image in the batch + images_np = [] + for b in range(image_tensor.size(0)): + one_image = image_tensor[b] + one_image_np = tensor2im(one_image) + images_np.append(one_image_np.reshape(1, *one_image_np.shape)) + images_np = np.concatenate(images_np, axis=0) + + return images_np + + if image_tensor.dim() == 2: + image_tensor = image_tensor.unsqueeze(0) + image_numpy = image_tensor.detach().cpu().float().numpy() + if normalize: + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + else: + image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 + image_numpy = np.clip(image_numpy, 0, 255) + if image_numpy.shape[2] == 1: + image_numpy = image_numpy[:, :, 0] + return image_numpy.astype(imtype) + + +# Converts a one-hot tensor into a colorful label map +def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False): + if label_tensor.dim() == 4: + # transform each image in the batch + images_np = [] + for b in range(label_tensor.size(0)): + one_image = label_tensor[b] + one_image_np = tensor2label(one_image, n_label, imtype) + images_np.append(one_image_np.reshape(1, *one_image_np.shape)) + images_np = np.concatenate(images_np, axis=0) + # if tile: + # images_tiled = tile_images(images_np) + # return images_tiled + # else: + # images_np = images_np[0] + # return images_np + return images_np + + if label_tensor.dim() == 1: + return np.zeros((64, 64, 3), dtype=np.uint8) + if n_label == 0: + return tensor2im(label_tensor, imtype) + label_tensor = label_tensor.cpu().float() + if label_tensor.size()[0] > 1: + label_tensor = label_tensor.max(0, keepdim=True)[1] + label_tensor = Colorize(n_label)(label_tensor) + label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) + result = label_numpy.astype(imtype) + return result + + +def save_image(image_numpy, image_path, create_dir=False): + if create_dir: + os.makedirs(os.path.dirname(image_path), exist_ok=True) + if len(image_numpy.shape) == 2: + image_numpy = np.expand_dims(image_numpy, axis=2) + if image_numpy.shape[2] == 1: + image_numpy = np.repeat(image_numpy, 3, 2) + image_pil = Image.fromarray(image_numpy) + + # save to png + image_pil.save(image_path.replace(".jpg", ".png")) + + +def mkdirs(paths): + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + """ + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + """ + return [atoi(c) for c in re.split("(\d+)", text)] + + +def natural_sort(items): + items.sort(key=natural_keys) + + +def str2bool(v): + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + +def find_class_in_module(target_cls_name, module): + target_cls_name = target_cls_name.replace("_", "").lower() + clslib = importlib.import_module(module) + cls = None + for name, clsobj in clslib.__dict__.items(): + if name.lower() == target_cls_name: + cls = clsobj + + if cls is None: + print( + "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" + % (module, target_cls_name) + ) + exit(0) + + return cls + + +def save_network(net, label, epoch, opt): + save_filename = "%s_net_%s.pth" % (epoch, label) + save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename) + torch.save(net.cpu().state_dict(), save_path) + if len(opt.gpu_ids) and torch.cuda.is_available(): + net.cuda() + + +def load_network(net, label, epoch, opt): + save_filename = "%s_net_%s.pth" % (epoch, label) + save_dir = os.path.join(opt.checkpoints_dir, opt.name) + save_path = os.path.join(save_dir, save_filename) + if os.path.exists(save_path): + weights = torch.load(save_path) + net.load_state_dict(weights) + return net + + +############################################################################### +# Code from +# https://github.com/ycszen/pytorch-seg/blob/master/transform.py +# Modified so it complies with the Citscape label map colors +############################################################################### +def uint82bin(n, count=8): + """returns the binary of integer n, count refers to amount of bits""" + return "".join([str((n >> y) & 1) for y in range(count - 1, -1, -1)]) + + +class Colorize(object): + def __init__(self, n=35): + self.cmap = labelcolormap(n) + self.cmap = torch.from_numpy(self.cmap[:n]) + + def __call__(self, gray_image): + size = gray_image.size() + color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) + + for label in range(0, len(self.cmap)): + mask = (label == gray_image[0]).cpu() + color_image[0][mask] = self.cmap[label][0] + color_image[1][mask] = self.cmap[label][1] + color_image[2][mask] = self.cmap[label][2] + + return color_image diff --git a/Face_Enhancement/util/visualizer.py b/Face_Enhancement/util/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..2cc519b52e9e15f5891ac3f4dcab620793794322 --- /dev/null +++ b/Face_Enhancement/util/visualizer.py @@ -0,0 +1,134 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +import ntpath +import time +from . import util +import scipy.misc + +try: + from StringIO import StringIO # Python 2.7 +except ImportError: + from io import BytesIO # Python 3.x +import torchvision.utils as vutils +from tensorboardX import SummaryWriter +import torch +import numpy as np + + +class Visualizer: + def __init__(self, opt): + self.opt = opt + self.tf_log = opt.isTrain and opt.tf_log + + self.tensorboard_log = opt.tensorboard_log + + self.win_size = opt.display_winsize + self.name = opt.name + if self.tensorboard_log: + + if self.opt.isTrain: + self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, "logs") + if not os.path.exists(self.log_dir): + os.makedirs(self.log_dir) + self.writer = SummaryWriter(log_dir=self.log_dir) + else: + print("hi :)") + self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir) + if not os.path.exists(self.log_dir): + os.makedirs(self.log_dir) + + if opt.isTrain: + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, "loss_log.txt") + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write("================ Training Loss (%s) ================\n" % now) + + # |visuals|: dictionary of images to display or save + def display_current_results(self, visuals, epoch, step): + + all_tensor = [] + if self.tensorboard_log: + + for key, tensor in visuals.items(): + all_tensor.append((tensor.data.cpu() + 1) / 2) + + output = torch.cat(all_tensor, 0) + img_grid = vutils.make_grid(output, nrow=self.opt.batchSize, padding=0, normalize=False) + + if self.opt.isTrain: + self.writer.add_image("Face_SPADE/training_samples", img_grid, step) + else: + vutils.save_image( + output, + os.path.join(self.log_dir, str(step) + ".png"), + nrow=self.opt.batchSize, + padding=0, + normalize=False, + ) + + # errors: dictionary of error labels and values + def plot_current_errors(self, errors, step): + if self.tf_log: + for tag, value in errors.items(): + value = value.mean().float() + summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) + self.writer.add_summary(summary, step) + + if self.tensorboard_log: + + self.writer.add_scalar("Loss/GAN_Feat", errors["GAN_Feat"].mean().float(), step) + self.writer.add_scalar("Loss/VGG", errors["VGG"].mean().float(), step) + self.writer.add_scalars( + "Loss/GAN", + { + "G": errors["GAN"].mean().float(), + "D": (errors["D_Fake"].mean().float() + errors["D_real"].mean().float()) / 2, + }, + step, + ) + + # errors: same format as |errors| of plotCurrentErrors + def print_current_errors(self, epoch, i, errors, t): + message = "(epoch: %d, iters: %d, time: %.3f) " % (epoch, i, t) + for k, v in errors.items(): + v = v.mean().float() + message += "%s: %.3f " % (k, v) + + print(message) + with open(self.log_name, "a") as log_file: + log_file.write("%s\n" % message) + + def convert_visuals_to_numpy(self, visuals): + for key, t in visuals.items(): + tile = self.opt.batchSize > 8 + if "input_label" == key: + t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile) ## B*H*W*C 0-255 numpy + else: + t = util.tensor2im(t, tile=tile) + visuals[key] = t + return visuals + + # save image to the disk + def save_images(self, webpage, visuals, image_path): + visuals = self.convert_visuals_to_numpy(visuals) + + image_dir = webpage.get_image_dir() + short_path = ntpath.basename(image_path[0]) + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims = [] + txts = [] + links = [] + + for label, image_numpy in visuals.items(): + image_name = os.path.join(label, "%s.png" % (name)) + save_path = os.path.join(image_dir, image_name) + util.save_image(image_numpy, save_path, create_dir=True) + + ims.append(image_name) + txts.append(label) + links.append(image_name) + webpage.add_images(ims, txts, links, width=self.win_size) diff --git a/GUI.py b/GUI.py new file mode 100644 index 0000000000000000000000000000000000000000..c0ee59832e72d2802793a12beebc1032275bd19e --- /dev/null +++ b/GUI.py @@ -0,0 +1,217 @@ +import numpy as np +import cv2 +import PySimpleGUI as sg +import os.path +import argparse +import os +import sys +import shutil +from subprocess import call + +def modify(image_filename=None, cv2_frame=None): + + def run_cmd(command): + try: + call(command, shell=True) + except KeyboardInterrupt: + print("Process interrupted") + sys.exit(1) + + parser = argparse.ArgumentParser() + parser.add_argument("--input_folder", type=str, + default= image_filename, help="Test images") + parser.add_argument( + "--output_folder", + type=str, + default="./output", + help="Restored images, please use the absolute path", + ) + parser.add_argument("--GPU", type=str, default="-1", help="0,1,2") + parser.add_argument( + "--checkpoint_name", type=str, default="Setting_9_epoch_100", help="choose which checkpoint" + ) + parser.add_argument("--with_scratch",default="--with_scratch" ,action="store_true") + opts = parser.parse_args() + + gpu1 = opts.GPU + + # resolve relative paths before changing directory + opts.input_folder = os.path.abspath(opts.input_folder) + opts.output_folder = os.path.abspath(opts.output_folder) + if not os.path.exists(opts.output_folder): + os.makedirs(opts.output_folder) + + main_environment = os.getcwd() + + # Stage 1: Overall Quality Improve + print("Running Stage 1: Overall restoration") + os.chdir("./Global") + stage_1_input_dir = opts.input_folder + stage_1_output_dir = os.path.join( + opts.output_folder, "stage_1_restore_output") + if not os.path.exists(stage_1_output_dir): + os.makedirs(stage_1_output_dir) + + if not opts.with_scratch: + stage_1_command = ( + "python test.py --test_mode Full --Quality_restore --test_input " + + stage_1_input_dir + + " --outputs_dir " + + stage_1_output_dir + + " --gpu_ids " + + gpu1 + ) + run_cmd(stage_1_command) + else: + + mask_dir = os.path.join(stage_1_output_dir, "masks") + new_input = os.path.join(mask_dir, "input") + new_mask = os.path.join(mask_dir, "mask") + stage_1_command_1 = ( + "python detection.py --test_path " + + stage_1_input_dir + + " --output_dir " + + mask_dir + + " --input_size full_size" + + " --GPU " + + gpu1 + ) + stage_1_command_2 = ( + "python test.py --Scratch_and_Quality_restore --test_input " + + new_input + + " --test_mask " + + new_mask + + " --outputs_dir " + + stage_1_output_dir + + " --gpu_ids " + + gpu1 + ) + run_cmd(stage_1_command_1) + run_cmd(stage_1_command_2) + + # Solve the case when there is no face in the old photo + stage_1_results = os.path.join(stage_1_output_dir, "restored_image") + stage_4_output_dir = os.path.join(opts.output_folder, "final_output") + if not os.path.exists(stage_4_output_dir): + os.makedirs(stage_4_output_dir) + for x in os.listdir(stage_1_results): + img_dir = os.path.join(stage_1_results, x) + shutil.copy(img_dir, stage_4_output_dir) + + print("Finish Stage 1 ...") + print("\n") + + # Stage 2: Face Detection + + print("Running Stage 2: Face Detection") + os.chdir(".././Face_Detection") + stage_2_input_dir = os.path.join(stage_1_output_dir, "restored_image") + stage_2_output_dir = os.path.join( + opts.output_folder, "stage_2_detection_output") + if not os.path.exists(stage_2_output_dir): + os.makedirs(stage_2_output_dir) + stage_2_command = ( + "python detect_all_dlib.py --url " + stage_2_input_dir + + " --save_url " + stage_2_output_dir + ) + run_cmd(stage_2_command) + print("Finish Stage 2 ...") + print("\n") + + # Stage 3: Face Restore + print("Running Stage 3: Face Enhancement") + os.chdir(".././Face_Enhancement") + stage_3_input_mask = "./" + stage_3_input_face = stage_2_output_dir + stage_3_output_dir = os.path.join( + opts.output_folder, "stage_3_face_output") + if not os.path.exists(stage_3_output_dir): + os.makedirs(stage_3_output_dir) + stage_3_command = ( + "python test_face.py --old_face_folder " + + stage_3_input_face + + " --old_face_label_folder " + + stage_3_input_mask + + " --tensorboard_log --name " + + opts.checkpoint_name + + " --gpu_ids " + + gpu1 + + " --load_size 256 --label_nc 18 --no_instance --preprocess_mode resize --batchSize 4 --results_dir " + + stage_3_output_dir + + " --no_parsing_map" + ) + run_cmd(stage_3_command) + print("Finish Stage 3 ...") + print("\n") + + # Stage 4: Warp back + print("Running Stage 4: Blending") + os.chdir(".././Face_Detection") + stage_4_input_image_dir = os.path.join( + stage_1_output_dir, "restored_image") + stage_4_input_face_dir = os.path.join(stage_3_output_dir, "each_img") + stage_4_output_dir = os.path.join(opts.output_folder, "final_output") + if not os.path.exists(stage_4_output_dir): + os.makedirs(stage_4_output_dir) + stage_4_command = ( + "python align_warp_back_multiple_dlib.py --origin_url " + + stage_4_input_image_dir + + " --replace_url " + + stage_4_input_face_dir + + " --save_url " + + stage_4_output_dir + ) + run_cmd(stage_4_command) + print("Finish Stage 4 ...") + print("\n") + + print("All the processing is done. Please check the results.") + +# --------------------------------- The GUI --------------------------------- + +# First the window layout... + +images_col = [[sg.Text('Input file:'), sg.In(enable_events=True, key='-IN FILE-'), sg.FileBrowse()], + [sg.Button('Modify Photo', key='-MPHOTO-'), sg.Button('Exit')], + [sg.Image(filename='', key='-IN-'), sg.Image(filename='', key='-OUT-')],] +# ----- Full layout ----- +layout = [[sg.VSeperator(), sg.Column(images_col)]] + +# ----- Make the window ----- +window = sg.Window('Bringing-old-photos-back-to-life', layout, grab_anywhere=True) + +# ----- Run the Event Loop ----- +prev_filename = colorized = cap = None +while True: + event, values = window.read() + if event in (None, 'Exit'): + break + + elif event == '-MPHOTO-': + try: + n1 = filename.split("/")[-2] + n2 = filename.split("/")[-3] + n3 = filename.split("/")[-1] + filename= str(f"./{n2}/{n1}") + modify(filename) + + global f_image + f_image = f'./output/final_output/{n3}' + image = cv2.imread(f_image) + window['-OUT-'].update(data=cv2.imencode('.png', image)[1].tobytes()) + + except: + continue + + elif event == '-IN FILE-': # A single filename was chosen + filename = values['-IN FILE-'] + if filename != prev_filename: + prev_filename = filename + try: + image = cv2.imread(filename) + window['-IN-'].update(data=cv2.imencode('.png', image)[1].tobytes()) + except: + continue + +# ----- Exit program ----- +window.close() \ No newline at end of file diff --git a/Global/data/Create_Bigfile.py b/Global/data/Create_Bigfile.py new file mode 100644 index 0000000000000000000000000000000000000000..2df6ef3ceec4a80903410901fd9656d9707af84e --- /dev/null +++ b/Global/data/Create_Bigfile.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +import struct +from PIL import Image + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + #print(fname) + path = os.path.join(root, fname) + images.append(path) + + return images + +### Modify these 3 lines in your own environment +indir="/home/ziyuwan/workspace/data/temp_old" +target_folders=['VOC','Real_L_old','Real_RGB_old'] +out_dir ="/home/ziyuwan/workspace/data/temp_old" +### + +if os.path.exists(out_dir) is False: + os.makedirs(out_dir) + +# +for target_folder in target_folders: + curr_indir = os.path.join(indir, target_folder) + curr_out_file = os.path.join(os.path.join(out_dir, '%s.bigfile'%(target_folder))) + image_lists = make_dataset(curr_indir) + image_lists.sort() + with open(curr_out_file, 'wb') as wfid: + # write total image number + wfid.write(struct.pack('i', len(image_lists))) + for i, img_path in enumerate(image_lists): + # write file name first + img_name = os.path.basename(img_path) + img_name_bytes = img_name.encode('utf-8') + wfid.write(struct.pack('i', len(img_name_bytes))) + wfid.write(img_name_bytes) + # + # # write image data in + with open(img_path, 'rb') as img_fid: + img_bytes = img_fid.read() + wfid.write(struct.pack('i', len(img_bytes))) + wfid.write(img_bytes) + + if i % 1000 == 0: + print('write %d images done' % i) \ No newline at end of file diff --git a/Global/data/Load_Bigfile.py b/Global/data/Load_Bigfile.py new file mode 100644 index 0000000000000000000000000000000000000000..b34f1ece4d296f4e7e8ccb709d84a23c01ee5dd7 --- /dev/null +++ b/Global/data/Load_Bigfile.py @@ -0,0 +1,42 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import io +import os +import struct +from PIL import Image + +class BigFileMemoryLoader(object): + def __load_bigfile(self): + print('start load bigfile (%0.02f GB) into memory' % (os.path.getsize(self.file_path)/1024/1024/1024)) + with open(self.file_path, 'rb') as fid: + self.img_num = struct.unpack('i', fid.read(4))[0] + self.img_names = [] + self.img_bytes = [] + print('find total %d images' % self.img_num) + for i in range(self.img_num): + img_name_len = struct.unpack('i', fid.read(4))[0] + img_name = fid.read(img_name_len).decode('utf-8') + self.img_names.append(img_name) + img_bytes_len = struct.unpack('i', fid.read(4))[0] + self.img_bytes.append(fid.read(img_bytes_len)) + if i % 5000 == 0: + print('load %d images done' % i) + print('load all %d images done' % self.img_num) + + def __init__(self, file_path): + super(BigFileMemoryLoader, self).__init__() + self.file_path = file_path + self.__load_bigfile() + + def __getitem__(self, index): + try: + img = Image.open(io.BytesIO(self.img_bytes[index])).convert('RGB') + return self.img_names[index], img + except Exception: + print('Image read error for index %d: %s' % (index, self.img_names[index])) + return self.__getitem__((index+1)%self.img_num) + + + def __len__(self): + return self.img_num diff --git a/Global/data/__init__.py b/Global/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Global/data/base_data_loader.py b/Global/data/base_data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..05b12839de5bf604252c46a14c76c3d74db27868 --- /dev/null +++ b/Global/data/base_data_loader.py @@ -0,0 +1,16 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +class BaseDataLoader(): + def __init__(self): + pass + + def initialize(self, opt): + self.opt = opt + pass + + def load_data(): + return None + + + diff --git a/Global/data/base_dataset.py b/Global/data/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5f0ac562eacc926b606f70c9dea680021dec2edc --- /dev/null +++ b/Global/data/base_dataset.py @@ -0,0 +1,114 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms +import numpy as np +import random + +class BaseDataset(data.Dataset): + def __init__(self): + super(BaseDataset, self).__init__() + + def name(self): + return 'BaseDataset' + + def initialize(self, opt): + pass + +def get_params(opt, size): + w, h = size + new_h = h + new_w = w + if opt.resize_or_crop == 'resize_and_crop': + new_h = new_w = opt.loadSize + + if opt.resize_or_crop == 'scale_width_and_crop': # we scale the shorter side into 256 + + if w 0.5 + return {'crop_pos': (x, y), 'flip': flip} + +def get_transform(opt, params, method=Image.BICUBIC, normalize=True): + transform_list = [] + if 'resize' in opt.resize_or_crop: + osize = [opt.loadSize, opt.loadSize] + transform_list.append(transforms.Scale(osize, method)) + elif 'scale_width' in opt.resize_or_crop: + # transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) ## Here , We want the shorter side to match 256, and Scale will finish it. + transform_list.append(transforms.Scale(256,method)) + + if 'crop' in opt.resize_or_crop: + if opt.isTrain: + transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) + else: + if opt.test_random_crop: + transform_list.append(transforms.RandomCrop(opt.fineSize)) + else: + transform_list.append(transforms.CenterCrop(opt.fineSize)) + + ## when testing, for ablation study, choose center_crop directly. + + + + if opt.resize_or_crop == 'none': + base = float(2 ** opt.n_downsample_global) + if opt.netG == 'local': + base *= (2 ** opt.n_local_enhancers) + transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) + + if opt.isTrain and not opt.no_flip: + transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) + + transform_list += [transforms.ToTensor()] + + if normalize: + transform_list += [transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + return transforms.Compose(transform_list) + +def normalize(): + return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + +def __make_power_2(img, base, method=Image.BICUBIC): + ow, oh = img.size + h = int(round(oh / base) * base) + w = int(round(ow / base) * base) + if (h == oh) and (w == ow): + return img + return img.resize((w, h), method) + +def __scale_width(img, target_width, method=Image.BICUBIC): + ow, oh = img.size + if (ow == target_width): + return img + w = target_width + h = int(target_width * oh / ow) + return img.resize((w, h), method) + +def __crop(img, pos, size): + ow, oh = img.size + x1, y1 = pos + tw = th = size + if (ow > tw or oh > th): + return img.crop((x1, y1, x1 + tw, y1 + th)) + return img + +def __flip(img, flip): + if flip: + return img.transpose(Image.FLIP_LEFT_RIGHT) + return img diff --git a/Global/data/custom_dataset_data_loader.py b/Global/data/custom_dataset_data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..04cc03203f216bb931eefb29b0c71c3dedaadae0 --- /dev/null +++ b/Global/data/custom_dataset_data_loader.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch.utils.data +import random +from data.base_data_loader import BaseDataLoader +from data import online_dataset_for_old_photos as dts_ray_bigfile + + +def CreateDataset(opt): + dataset = None + if opt.training_dataset=='domain_A' or opt.training_dataset=='domain_B': + dataset = dts_ray_bigfile.UnPairOldPhotos_SR() + if opt.training_dataset=='mapping': + if opt.random_hole: + dataset = dts_ray_bigfile.PairOldPhotos_with_hole() + else: + dataset = dts_ray_bigfile.PairOldPhotos() + print("dataset [%s] was created" % (dataset.name())) + dataset.initialize(opt) + return dataset + +class CustomDatasetDataLoader(BaseDataLoader): + def name(self): + return 'CustomDatasetDataLoader' + + def initialize(self, opt): + BaseDataLoader.initialize(self, opt) + self.dataset = CreateDataset(opt) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batchSize, + shuffle=not opt.serial_batches, + num_workers=int(opt.nThreads), + drop_last=True) + + def load_data(self): + return self.dataloader + + def __len__(self): + return min(len(self.dataset), self.opt.max_dataset_size) diff --git a/Global/data/data_loader.py b/Global/data/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..02ccaedcc08b2201dabcda4a80fd59c6cd8a8068 --- /dev/null +++ b/Global/data/data_loader.py @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +def CreateDataLoader(opt): + from data.custom_dataset_data_loader import CustomDatasetDataLoader + data_loader = CustomDatasetDataLoader() + print(data_loader.name()) + data_loader.initialize(opt) + return data_loader diff --git a/Global/data/image_folder.py b/Global/data/image_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..8b1b9563ff3c3e18b8547c39cf708d583c68c29b --- /dev/null +++ b/Global/data/image_folder.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch.utils.data as data +from PIL import Image +import os + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + + return images + + +def default_loader(path): + return Image.open(path).convert('RGB') + + +class ImageFolder(data.Dataset): + + def __init__(self, root, transform=None, return_paths=False, + loader=default_loader): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in: " + root + "\n" + "Supported image extensions are: " + + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) diff --git a/Global/data/online_dataset_for_old_photos.py b/Global/data/online_dataset_for_old_photos.py new file mode 100644 index 0000000000000000000000000000000000000000..068410a93eb10d5f00e694fd890f8aaa069526a3 --- /dev/null +++ b/Global/data/online_dataset_for_old_photos.py @@ -0,0 +1,485 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os.path +import io +import zipfile +from data.base_dataset import BaseDataset, get_params, get_transform, normalize +from data.image_folder import make_dataset +from PIL import Image +import torchvision.transforms as transforms +import numpy as np +from data.Load_Bigfile import BigFileMemoryLoader +import random +import cv2 +from io import BytesIO + +def pil_to_np(img_PIL): + '''Converts image in PIL format to np.array. + + From W x H x C [0...255] to C x W x H [0..1] + ''' + ar = np.array(img_PIL) + + if len(ar.shape) == 3: + ar = ar.transpose(2, 0, 1) + else: + ar = ar[None, ...] + + return ar.astype(np.float32) / 255. + + +def np_to_pil(img_np): + '''Converts image in np.array format to PIL image. + + From C x W x H [0..1] to W x H x C [0...255] + ''' + ar = np.clip(img_np * 255, 0, 255).astype(np.uint8) + + if img_np.shape[0] == 1: + ar = ar[0] + else: + ar = ar.transpose(1, 2, 0) + + return Image.fromarray(ar) + +def synthesize_salt_pepper(image,amount,salt_vs_pepper): + + ## Give PIL, return the noisy PIL + + img_pil=pil_to_np(image) + + out = img_pil.copy() + p = amount + q = salt_vs_pepper + flipped = np.random.choice([True, False], size=img_pil.shape, + p=[p, 1 - p]) + salted = np.random.choice([True, False], size=img_pil.shape, + p=[q, 1 - q]) + peppered = ~salted + out[flipped & salted] = 1 + out[flipped & peppered] = 0. + noisy = np.clip(out, 0, 1).astype(np.float32) + + + return np_to_pil(noisy) + +def synthesize_gaussian(image,std_l,std_r): + + ## Give PIL, return the noisy PIL + + img_pil=pil_to_np(image) + + mean=0 + std=random.uniform(std_l/255.,std_r/255.) + gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape) + noisy=img_pil+gauss + noisy=np.clip(noisy,0,1).astype(np.float32) + + return np_to_pil(noisy) + +def synthesize_speckle(image,std_l,std_r): + + ## Give PIL, return the noisy PIL + + img_pil=pil_to_np(image) + + mean=0 + std=random.uniform(std_l/255.,std_r/255.) + gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape) + noisy=img_pil+gauss*img_pil + noisy=np.clip(noisy,0,1).astype(np.float32) + + return np_to_pil(noisy) + + +def synthesize_low_resolution(img): + w,h=img.size + + new_w=random.randint(int(w/2),w) + new_h=random.randint(int(h/2),h) + + img=img.resize((new_w,new_h),Image.BICUBIC) + + if random.uniform(0,1)<0.5: + img=img.resize((w,h),Image.NEAREST) + else: + img = img.resize((w, h), Image.BILINEAR) + + return img + + +def convertToJpeg(im,quality): + with BytesIO() as f: + im.save(f, format='JPEG',quality=quality) + f.seek(0) + return Image.open(f).convert('RGB') + + +def blur_image_v2(img): + + + x=np.array(img) + kernel_size_candidate=[(3,3),(5,5),(7,7)] + kernel_size=random.sample(kernel_size_candidate,1)[0] + std=random.uniform(1.,5.) + + #print("The gaussian kernel size: (%d,%d) std: %.2f"%(kernel_size[0],kernel_size[1],std)) + blur=cv2.GaussianBlur(x,kernel_size,std) + + return Image.fromarray(blur.astype(np.uint8)) + +def online_add_degradation_v2(img): + + task_id=np.random.permutation(4) + + for x in task_id: + if x==0 and random.uniform(0,1)<0.7: + img = blur_image_v2(img) + if x==1 and random.uniform(0,1)<0.7: + flag = random.choice([1, 2, 3]) + if flag == 1: + img = synthesize_gaussian(img, 5, 50) + if flag == 2: + img = synthesize_speckle(img, 5, 50) + if flag == 3: + img = synthesize_salt_pepper(img, random.uniform(0, 0.01), random.uniform(0.3, 0.8)) + if x==2 and random.uniform(0,1)<0.7: + img=synthesize_low_resolution(img) + + if x==3 and random.uniform(0,1)<0.7: + img=convertToJpeg(img,random.randint(40,100)) + + return img + + +def irregular_hole_synthesize(img,mask): + + img_np=np.array(img).astype('uint8') + mask_np=np.array(mask).astype('uint8') + mask_np=mask_np/255 + img_new=img_np*(1-mask_np)+mask_np*255 + + + hole_img=Image.fromarray(img_new.astype('uint8')).convert("RGB") + + return hole_img,mask.convert("L") + +def zero_mask(size): + x=np.zeros((size,size,3)).astype('uint8') + mask=Image.fromarray(x).convert("RGB") + return mask + + + +class UnPairOldPhotos_SR(BaseDataset): ## Synthetic + Real Old + def initialize(self, opt): + self.opt = opt + self.isImage = 'domainA' in opt.name + self.task = 'old_photo_restoration_training_vae' + self.dir_AB = opt.dataroot + if self.isImage: + + self.load_img_dir_L_old=os.path.join(self.dir_AB,"Real_L_old.bigfile") + self.load_img_dir_RGB_old=os.path.join(self.dir_AB,"Real_RGB_old.bigfile") + self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile") + + self.loaded_imgs_L_old=BigFileMemoryLoader(self.load_img_dir_L_old) + self.loaded_imgs_RGB_old=BigFileMemoryLoader(self.load_img_dir_RGB_old) + self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean) + + else: + # self.load_img_dir_clean=os.path.join(self.dir_AB,self.opt.test_dataset) + self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile") + self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean) + + #### + print("-------------Filter the imgs whose size <256 in VOC-------------") + self.filtered_imgs_clean=[] + for i in range(len(self.loaded_imgs_clean)): + img_name,img=self.loaded_imgs_clean[i] + h,w=img.size + if h<256 or w<256: + continue + self.filtered_imgs_clean.append((img_name,img)) + + print("--------Origin image num is [%d], filtered result is [%d]--------" % ( + len(self.loaded_imgs_clean), len(self.filtered_imgs_clean))) + ## Filter these images whose size is less than 256 + + # self.img_list=os.listdir(load_img_dir) + self.pid = os.getpid() + + def __getitem__(self, index): + + + is_real_old=0 + + sampled_dataset=None + degradation=None + if self.isImage: ## domain A , contains 2 kinds of data: synthetic + real_old + P=random.uniform(0,2) + if P>=0 and P<1: + if random.uniform(0,1)<0.5: + sampled_dataset=self.loaded_imgs_L_old + self.load_img_dir=self.load_img_dir_L_old + else: + sampled_dataset=self.loaded_imgs_RGB_old + self.load_img_dir=self.load_img_dir_RGB_old + is_real_old=1 + if P>=1 and P<2: + sampled_dataset=self.filtered_imgs_clean + self.load_img_dir=self.load_img_dir_clean + degradation=1 + else: + + sampled_dataset=self.filtered_imgs_clean + self.load_img_dir=self.load_img_dir_clean + + sampled_dataset_len=len(sampled_dataset) + + index=random.randint(0,sampled_dataset_len-1) + + img_name,img = sampled_dataset[index] + + if degradation is not None: + img=online_add_degradation_v2(img) + + path=os.path.join(self.load_img_dir,img_name) + + # AB = Image.open(path).convert('RGB') + # split AB image into A and B + + # apply the same transform to both A and B + + if random.uniform(0,1) <0.1: + img=img.convert("L") + img=img.convert("RGB") + ## Give a probability P, we convert the RGB image into L + + + A=img + w,h=A.size + if w<256 or h<256: + A=transforms.Scale(256,Image.BICUBIC)(A) + ## Since we want to only crop the images (256*256), for those old photos whose size is smaller than 256, we first resize them. + + transform_params = get_params(self.opt, A.size) + A_transform = get_transform(self.opt, transform_params) + + B_tensor = inst_tensor = feat_tensor = 0 + A_tensor = A_transform(A) + + + input_dict = {'label': A_tensor, 'inst': is_real_old, 'image': A_tensor, + 'feat': feat_tensor, 'path': path} + return input_dict + + def __len__(self): + return len(self.loaded_imgs_clean) ## actually, this is useless, since the selected index is just a random number + + def name(self): + return 'UnPairOldPhotos_SR' + + +class PairOldPhotos(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.isImage = 'imagegan' in opt.name + self.task = 'old_photo_restoration_training_mapping' + self.dir_AB = opt.dataroot + if opt.isTrain: + self.load_img_dir_clean= os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile") + self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean) + + print("-------------Filter the imgs whose size <256 in VOC-------------") + self.filtered_imgs_clean = [] + for i in range(len(self.loaded_imgs_clean)): + img_name, img = self.loaded_imgs_clean[i] + h, w = img.size + if h < 256 or w < 256: + continue + self.filtered_imgs_clean.append((img_name, img)) + + print("--------Origin image num is [%d], filtered result is [%d]--------" % ( + len(self.loaded_imgs_clean), len(self.filtered_imgs_clean))) + + else: + self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset) + self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir) + + self.pid = os.getpid() + + def __getitem__(self, index): + + + + if self.opt.isTrain: + img_name_clean,B = self.filtered_imgs_clean[index] + path = os.path.join(self.load_img_dir_clean, img_name_clean) + if self.opt.use_v2_degradation: + A=online_add_degradation_v2(B) + ### Remind: A is the input and B is corresponding GT + else: + + if self.opt.test_on_synthetic: + + img_name_B,B=self.loaded_imgs[index] + A=online_add_degradation_v2(B) + img_name_A=img_name_B + path = os.path.join(self.load_img_dir, img_name_A) + else: + img_name_A,A=self.loaded_imgs[index] + img_name_B,B=self.loaded_imgs[index] + path = os.path.join(self.load_img_dir, img_name_A) + + + if random.uniform(0,1)<0.1 and self.opt.isTrain: + A=A.convert("L") + B=B.convert("L") + A=A.convert("RGB") + B=B.convert("RGB") + ## In P, we convert the RGB into L + + + ##test on L + + # split AB image into A and B + # w, h = img.size + # w2 = int(w / 2) + # A = img.crop((0, 0, w2, h)) + # B = img.crop((w2, 0, w, h)) + w,h=A.size + if w<256 or h<256: + A=transforms.Scale(256,Image.BICUBIC)(A) + B=transforms.Scale(256, Image.BICUBIC)(B) + + # apply the same transform to both A and B + transform_params = get_params(self.opt, A.size) + A_transform = get_transform(self.opt, transform_params) + B_transform = get_transform(self.opt, transform_params) + + B_tensor = inst_tensor = feat_tensor = 0 + A_tensor = A_transform(A) + B_tensor = B_transform(B) + + input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor, + 'feat': feat_tensor, 'path': path} + return input_dict + + def __len__(self): + + if self.opt.isTrain: + return len(self.filtered_imgs_clean) + else: + return len(self.loaded_imgs) + + def name(self): + return 'PairOldPhotos' + + +class PairOldPhotos_with_hole(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.isImage = 'imagegan' in opt.name + self.task = 'old_photo_restoration_training_mapping' + self.dir_AB = opt.dataroot + if opt.isTrain: + self.load_img_dir_clean= os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile") + self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean) + + print("-------------Filter the imgs whose size <256 in VOC-------------") + self.filtered_imgs_clean = [] + for i in range(len(self.loaded_imgs_clean)): + img_name, img = self.loaded_imgs_clean[i] + h, w = img.size + if h < 256 or w < 256: + continue + self.filtered_imgs_clean.append((img_name, img)) + + print("--------Origin image num is [%d], filtered result is [%d]--------" % ( + len(self.loaded_imgs_clean), len(self.filtered_imgs_clean))) + + else: + self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset) + self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir) + + self.loaded_masks = BigFileMemoryLoader(opt.irregular_mask) + + self.pid = os.getpid() + + def __getitem__(self, index): + + + + if self.opt.isTrain: + img_name_clean,B = self.filtered_imgs_clean[index] + path = os.path.join(self.load_img_dir_clean, img_name_clean) + + + B=transforms.RandomCrop(256)(B) + A=online_add_degradation_v2(B) + ### Remind: A is the input and B is corresponding GT + + else: + img_name_A,A=self.loaded_imgs[index] + img_name_B,B=self.loaded_imgs[index] + path = os.path.join(self.load_img_dir, img_name_A) + + #A=A.resize((256,256)) + A=transforms.CenterCrop(256)(A) + B=A + + if random.uniform(0,1)<0.1 and self.opt.isTrain: + A=A.convert("L") + B=B.convert("L") + A=A.convert("RGB") + B=B.convert("RGB") + ## In P, we convert the RGB into L + + if self.opt.isTrain: + mask_name,mask=self.loaded_masks[random.randint(0,len(self.loaded_masks)-1)] + else: + mask_name, mask = self.loaded_masks[index%100] + mask = mask.resize((self.opt.loadSize, self.opt.loadSize), Image.NEAREST) + + if self.opt.random_hole and random.uniform(0,1)>0.5 and self.opt.isTrain: + mask=zero_mask(256) + + if self.opt.no_hole: + mask=zero_mask(256) + + + A,_=irregular_hole_synthesize(A,mask) + + if not self.opt.isTrain and self.opt.hole_image_no_mask: + mask=zero_mask(256) + + transform_params = get_params(self.opt, A.size) + A_transform = get_transform(self.opt, transform_params) + B_transform = get_transform(self.opt, transform_params) + + if transform_params['flip'] and self.opt.isTrain: + mask=mask.transpose(Image.FLIP_LEFT_RIGHT) + + mask_tensor = transforms.ToTensor()(mask) + + + B_tensor = inst_tensor = feat_tensor = 0 + A_tensor = A_transform(A) + B_tensor = B_transform(B) + + input_dict = {'label': A_tensor, 'inst': mask_tensor[:1], 'image': B_tensor, + 'feat': feat_tensor, 'path': path} + return input_dict + + def __len__(self): + + if self.opt.isTrain: + return len(self.filtered_imgs_clean) + + else: + return len(self.loaded_imgs) + + def name(self): + return 'PairOldPhotos_with_hole' \ No newline at end of file diff --git a/Global/detection.py b/Global/detection.py new file mode 100644 index 0000000000000000000000000000000000000000..f23494a3855ace6255261a56237f67f3c8dc7294 --- /dev/null +++ b/Global/detection.py @@ -0,0 +1,178 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +import gc +import json +import os +import time +import warnings + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision as tv +from PIL import Image, ImageFile + +from detection_models import networks +from detection_util.util import * + +warnings.filterwarnings("ignore", category=UserWarning) + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +def data_transforms(img, full_size, method=Image.BICUBIC): + if full_size == "full_size": + ow, oh = img.size + h = int(round(oh / 16) * 16) + w = int(round(ow / 16) * 16) + if (h == oh) and (w == ow): + return img + return img.resize((w, h), method) + + elif full_size == "scale_256": + ow, oh = img.size + pw, ph = ow, oh + if ow < oh: + ow = 256 + oh = ph / pw * 256 + else: + oh = 256 + ow = pw / ph * 256 + + h = int(round(oh / 16) * 16) + w = int(round(ow / 16) * 16) + if (h == ph) and (w == pw): + return img + return img.resize((w, h), method) + + +def scale_tensor(img_tensor, default_scale=256): + _, _, w, h = img_tensor.shape + if w < h: + ow = default_scale + oh = h / w * default_scale + else: + oh = default_scale + ow = w / h * default_scale + + oh = int(round(oh / 16) * 16) + ow = int(round(ow / 16) * 16) + + return F.interpolate(img_tensor, [ow, oh], mode="bilinear") + + +def blend_mask(img, mask): + + np_img = np.array(img).astype("float") + + return Image.fromarray((np_img * (1 - mask) + mask * 255.0).astype("uint8")).convert("RGB") + + +def main(config): + print("initializing the dataloader") + + model = networks.UNet( + in_channels=1, + out_channels=1, + depth=4, + conv_num=2, + wf=6, + padding=True, + batch_norm=True, + up_mode="upsample", + with_tanh=False, + sync_bn=True, + antialiasing=True, + ) + + ## load model + checkpoint_path = os.path.join(os.path.dirname(__file__), "checkpoints/detection/FT_Epoch_latest.pt") + checkpoint = torch.load(checkpoint_path, map_location="cpu") + model.load_state_dict(checkpoint["model_state"]) + print("model weights loaded") + + if config.GPU >= 0: + model.to(config.GPU) + else: + model.cpu() + model.eval() + + ## dataloader and transformation + print("directory of testing image: " + config.test_path) + imagelist = os.listdir(config.test_path) + imagelist.sort() + total_iter = 0 + + P_matrix = {} + save_url = os.path.join(config.output_dir) + mkdir_if_not(save_url) + + input_dir = os.path.join(save_url, "input") + output_dir = os.path.join(save_url, "mask") + # blend_output_dir=os.path.join(save_url, 'blend_output') + mkdir_if_not(input_dir) + mkdir_if_not(output_dir) + # mkdir_if_not(blend_output_dir) + + idx = 0 + + results = [] + for image_name in imagelist: + + idx += 1 + + print("processing", image_name) + + scratch_file = os.path.join(config.test_path, image_name) + if not os.path.isfile(scratch_file): + print("Skipping non-file %s" % image_name) + continue + scratch_image = Image.open(scratch_file).convert("RGB") + w, h = scratch_image.size + + transformed_image_PIL = data_transforms(scratch_image, config.input_size) + scratch_image = transformed_image_PIL.convert("L") + scratch_image = tv.transforms.ToTensor()(scratch_image) + scratch_image = tv.transforms.Normalize([0.5], [0.5])(scratch_image) + scratch_image = torch.unsqueeze(scratch_image, 0) + _, _, ow, oh = scratch_image.shape + scratch_image_scale = scale_tensor(scratch_image) + + if config.GPU >= 0: + scratch_image_scale = scratch_image_scale.to(config.GPU) + else: + scratch_image_scale = scratch_image_scale.cpu() + with torch.no_grad(): + P = torch.sigmoid(model(scratch_image_scale)) + + P = P.data.cpu() + P = F.interpolate(P, [ow, oh], mode="nearest") + + tv.utils.save_image( + (P >= 0.4).float(), + os.path.join( + output_dir, + image_name[:-4] + ".png", + ), + nrow=1, + padding=0, + normalize=True, + ) + transformed_image_PIL.save(os.path.join(input_dir, image_name[:-4] + ".png")) + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # parser.add_argument('--checkpoint_name', type=str, default="FT_Epoch_latest.pt", help='Checkpoint Name') + + parser.add_argument("--GPU", type=int, default=0) + parser.add_argument("--test_path", type=str, default=".") + parser.add_argument("--output_dir", type=str, default=".") + parser.add_argument("--input_size", type=str, default="scale_256", help="resize_256|full_size|scale_256") + config = parser.parse_args() + + main(config) diff --git a/Global/detection_models/Synchronized-BatchNorm-PyTorch/LICENSE b/Global/detection_models/Synchronized-BatchNorm-PyTorch/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4c39939e7e3aa940d405030335ec0e6ff2f2a1ee --- /dev/null +++ b/Global/detection_models/Synchronized-BatchNorm-PyTorch/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 Jiayuan MAO + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Global/detection_models/Synchronized-BatchNorm-PyTorch/README.md b/Global/detection_models/Synchronized-BatchNorm-PyTorch/README.md new file mode 100644 index 0000000000000000000000000000000000000000..779983436c9727dd0d6301a1c857f2360245b51d --- /dev/null +++ b/Global/detection_models/Synchronized-BatchNorm-PyTorch/README.md @@ -0,0 +1,118 @@ +# Synchronized-BatchNorm-PyTorch + +**IMPORTANT: Please read the "Implementation details and highlights" section before use.** + +Synchronized Batch Normalization implementation in PyTorch. + +This module differs from the built-in PyTorch BatchNorm as the mean and +standard-deviation are reduced across all devices during training. + +For example, when one uses `nn.DataParallel` to wrap the network during +training, PyTorch's implementation normalize the tensor on each device using +the statistics only on that device, which accelerated the computation and +is also easy to implement, but the statistics might be inaccurate. +Instead, in this synchronized version, the statistics will be computed +over all training samples distributed on multiple devices. + +Note that, for one-GPU or CPU-only case, this module behaves exactly same +as the built-in PyTorch implementation. + +This module is currently only a prototype version for research usages. As mentioned below, +it has its limitations and may even suffer from some design problems. If you have any +questions or suggestions, please feel free to +[open an issue](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues) or +[submit a pull request](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues). + +## Why Synchronized BatchNorm? + +Although the typical implementation of BatchNorm working on multiple devices (GPUs) +is fast (with no communication overhead), it inevitably reduces the size of batch size, +which potentially degenerates the performance. This is not a significant issue in some +standard vision tasks such as ImageNet classification (as the batch size per device +is usually large enough to obtain good statistics). However, it will hurt the performance +in some tasks that the batch size is usually very small (e.g., 1 per GPU). + +For example, the importance of synchronized batch normalization in object detection has been recently proved with a +an extensive analysis in the paper [MegDet: A Large Mini-Batch Object Detector](https://arxiv.org/abs/1711.07240). + +## Usage + +To use the Synchronized Batch Normalization, we add a data parallel replication callback. This introduces a slight +difference with typical usage of the `nn.DataParallel`. + +Use it with a provided, customized data parallel wrapper: + +```python +from sync_batchnorm import SynchronizedBatchNorm1d, DataParallelWithCallback + +sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) +sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) +``` + +Or, if you are using a customized data parallel module, you can use this library as a monkey patching. + +```python +from torch.nn import DataParallel # or your customized DataParallel module +from sync_batchnorm import SynchronizedBatchNorm1d, patch_replication_callback + +sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) +sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) +patch_replication_callback(sync_bn) # monkey-patching +``` + +You can use `convert_model` to convert your model to use Synchronized BatchNorm easily. + +```python +import torch.nn as nn +from torchvision import models +from sync_batchnorm import convert_model +# m is a standard pytorch model +m = models.resnet18(True) +m = nn.DataParallel(m) +# after convert, m is using SyncBN +m = convert_model(m) +``` + +See also `tests/test_sync_batchnorm.py` for numeric result comparison. + +## Implementation details and highlights + +If you are interested in how batch statistics are reduced and broadcasted among multiple devices, please take a look +at the code with detailed comments. Here we only emphasize some highlights of the implementation: + +- This implementation is in pure-python. No C++ extra extension libs. +- Easy to use as demonstrated above. +- It uses unbiased variance to update the moving average, and use `sqrt(max(var, eps))` instead of `sqrt(var + eps)`. +- The implementation requires that each module on different devices should invoke the `batchnorm` for exactly SAME +amount of times in each forward pass. For example, you can not only call `batchnorm` on GPU0 but not on GPU1. The `#i +(i = 1, 2, 3, ...)` calls of the `batchnorm` on each device will be viewed as a whole and the statistics will be reduced. +This is tricky but is a good way to handle PyTorch's dynamic computation graph. Although sounds complicated, this +will usually not be the issue for most of the models. + +## Known issues + +#### Runtime error on backward pass. + +Due to a [PyTorch Bug](https://github.com/pytorch/pytorch/issues/3883), using old PyTorch libraries will trigger an `RuntimeError` with messages like: + +``` +Assertion `pos >= 0 && pos < buffer.size()` failed. +``` + +This has already been solved in the newest PyTorch repo, which, unfortunately, has not been pushed to the official and anaconda binary release. Thus, you are required to build the PyTorch package from the source according to the + instructions [here](https://github.com/pytorch/pytorch#from-source). + +#### Numeric error. + +Because this library does not fuse the normalization and statistics operations in C++ (nor CUDA), it is less +numerically stable compared to the original PyTorch implementation. Detailed analysis can be found in +`tests/test_sync_batchnorm.py`. + +## Authors and License: + +Copyright (c) 2018-, [Jiayuan Mao](https://vccy.xyz). + +**Contributors**: [Tete Xiao](https://tetexiao.com), [DTennant](https://github.com/DTennant). + +Distributed under **MIT License** (See LICENSE) + diff --git a/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/__init__.py b/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9b36c74b1808b56ded68cf080a689db7e0ee4e --- /dev/null +++ b/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import set_sbn_eps_mode +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .batchnorm import patch_sync_batchnorm, convert_model +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm.py b/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8d7a7325b474771a11a137053971fd40426079 --- /dev/null +++ b/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections +import contextlib + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm + +try: + from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast +except ImportError: + ReduceAddCoalesced = Broadcast = None + +try: + from jactorch.parallel.comm import SyncMaster + from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback +except ImportError: + from .comm import SyncMaster + from .replicate import DataParallelWithCallback + +__all__ = [ + 'set_sbn_eps_mode', + 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', + 'patch_sync_batchnorm', 'convert_model' +] + + +SBN_EPS_MODE = 'clamp' + + +def set_sbn_eps_mode(mode): + global SBN_EPS_MODE + assert mode in ('clamp', 'plus') + SBN_EPS_MODE = mode + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dimensions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): + assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' + + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, + track_running_stats=track_running_stats) + + if not self.track_running_stats: + import warnings + warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.') + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features) + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + if hasattr(torch, 'no_grad'): + with torch.no_grad(): + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + else: + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + if SBN_EPS_MODE == 'clamp': + return mean, bias_var.clamp(self.eps) ** -0.5 + elif SBN_EPS_MODE == 'plus': + return mean, (bias_var + self.eps) ** -0.5 + else: + raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE)) + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + + +@contextlib.contextmanager +def patch_sync_batchnorm(): + import torch.nn as nn + + backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d + + nn.BatchNorm1d = SynchronizedBatchNorm1d + nn.BatchNorm2d = SynchronizedBatchNorm2d + nn.BatchNorm3d = SynchronizedBatchNorm3d + + yield + + nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup + + +def convert_model(module): + """Traverse the input module and its child recursively + and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d + to SynchronizedBatchNorm*N*d + + Args: + module: the input module needs to be convert to SyncBN model + + Examples: + >>> import torch.nn as nn + >>> import torchvision + >>> # m is a standard pytorch model + >>> m = torchvision.models.resnet18(True) + >>> m = nn.DataParallel(m) + >>> # after convert, m is using SyncBN + >>> m = convert_model(m) + """ + if isinstance(module, torch.nn.DataParallel): + mod = module.module + mod = convert_model(mod) + mod = DataParallelWithCallback(mod, device_ids=module.device_ids) + return mod + + mod = module + for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, + torch.nn.modules.batchnorm.BatchNorm2d, + torch.nn.modules.batchnorm.BatchNorm3d], + [SynchronizedBatchNorm1d, + SynchronizedBatchNorm2d, + SynchronizedBatchNorm3d]): + if isinstance(module, pth_module): + mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) + mod.running_mean = module.running_mean + mod.running_var = module.running_var + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + + for name, child in module.named_children(): + mod.add_module(name, convert_model(child)) + + return mod diff --git a/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm_reimpl.py b/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm_reimpl.py new file mode 100644 index 0000000000000000000000000000000000000000..18145c3353e13d482c492ae46df91a537669fca0 --- /dev/null +++ b/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/batchnorm_reimpl.py @@ -0,0 +1,74 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : batchnorm_reimpl.py +# Author : acgtyrant +# Date : 11/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import torch +import torch.nn as nn +import torch.nn.init as init + +__all__ = ['BatchNorm2dReimpl'] + + +class BatchNorm2dReimpl(nn.Module): + """ + A re-implementation of batch normalization, used for testing the numerical + stability. + + Author: acgtyrant + See also: + https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 + """ + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super().__init__() + + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.weight = nn.Parameter(torch.empty(num_features)) + self.bias = nn.Parameter(torch.empty(num_features)) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_running_stats(self): + self.running_mean.zero_() + self.running_var.fill_(1) + + def reset_parameters(self): + self.reset_running_stats() + init.uniform_(self.weight) + init.zeros_(self.bias) + + def forward(self, input_): + batchsize, channels, height, width = input_.size() + numel = batchsize * height * width + input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) + sum_ = input_.sum(1) + sum_of_square = input_.pow(2).sum(1) + mean = sum_ / numel + sumvar = sum_of_square - sum_ * mean + + self.running_mean = ( + (1 - self.momentum) * self.running_mean + + self.momentum * mean.detach() + ) + unbias_var = sumvar / (numel - 1) + self.running_var = ( + (1 - self.momentum) * self.running_var + + self.momentum * unbias_var.detach() + ) + + bias_var = sumvar / numel + inv_std = 1 / (bias_var + self.eps).pow(0.5) + output = ( + (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * + self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) + + return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() + diff --git a/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/comm.py b/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..922f8c4a3adaa9b32fdcaef09583be03b0d7eb2b --- /dev/null +++ b/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/comm.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' + + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/replicate.py b/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..b71c7b8ed51a1d6c55b1f753bdd8d90bad79bd06 --- /dev/null +++ b/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/unittest.py b/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..998223a0e0242dc4a5b2fcd74af79dc7232794da --- /dev/null +++ b/Global/detection_models/Synchronized-BatchNorm-PyTorch/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import unittest +import torch + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, x, y): + adiff = float((x - y).abs().max()) + if (y == 0).all(): + rdiff = 'NaN' + else: + rdiff = float((adiff / y).abs().max()) + + message = ( + 'Tensor close check failed\n' + 'adiff={}\n' + 'rdiff={}\n' + ).format(adiff, rdiff) + self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message) + diff --git a/Global/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm.py b/Global/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..63661389782806ea2182c049448df5d05fc6d2f1 --- /dev/null +++ b/Global/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# File : test_numeric_batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. + +import unittest + +import torch +import torch.nn as nn +from torch.autograd import Variable + +from sync_batchnorm.unittest import TorchTestCase + + +def handy_var(a, unbias=True): + n = a.size(0) + asum = a.sum(dim=0) + as_sum = (a ** 2).sum(dim=0) # a square sum + sumvar = as_sum - asum * asum / n + if unbias: + return sumvar / (n - 1) + else: + return sumvar / n + + +class NumericTestCase(TorchTestCase): + def testNumericBatchNorm(self): + a = torch.rand(16, 10) + bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False) + bn.train() + + a_var1 = Variable(a, requires_grad=True) + b_var1 = bn(a_var1) + loss1 = b_var1.sum() + loss1.backward() + + a_var2 = Variable(a, requires_grad=True) + a_mean2 = a_var2.mean(dim=0, keepdim=True) + a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) + # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) + b_var2 = (a_var2 - a_mean2) / a_std2 + loss2 = b_var2.sum() + loss2.backward() + + self.assertTensorClose(bn.running_mean, a.mean(dim=0)) + self.assertTensorClose(bn.running_var, handy_var(a)) + self.assertTensorClose(a_var1.data, a_var2.data) + self.assertTensorClose(b_var1.data, b_var2.data) + self.assertTensorClose(a_var1.grad, a_var2.grad) + + +if __name__ == '__main__': + unittest.main() diff --git a/Global/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm_v2.py b/Global/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4538ae3c50b4c457a9fa19bf22b5b1a7b666ee --- /dev/null +++ b/Global/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_numeric_batchnorm_v2.py @@ -0,0 +1,62 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : test_numeric_batchnorm_v2.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 11/01/2018 +# +# Distributed under terms of the MIT license. + +""" +Test the numerical implementation of batch normalization. + +Author: acgtyrant. +See also: https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 +""" + +import unittest + +import torch +import torch.nn as nn +import torch.optim as optim + +from sync_batchnorm.unittest import TorchTestCase +from sync_batchnorm.batchnorm_reimpl import BatchNorm2dReimpl + + +class NumericTestCasev2(TorchTestCase): + def testNumericBatchNorm(self): + CHANNELS = 16 + batchnorm1 = nn.BatchNorm2d(CHANNELS, momentum=1) + optimizer1 = optim.SGD(batchnorm1.parameters(), lr=0.01) + + batchnorm2 = BatchNorm2dReimpl(CHANNELS, momentum=1) + batchnorm2.weight.data.copy_(batchnorm1.weight.data) + batchnorm2.bias.data.copy_(batchnorm1.bias.data) + optimizer2 = optim.SGD(batchnorm2.parameters(), lr=0.01) + + for _ in range(100): + input_ = torch.rand(16, CHANNELS, 16, 16) + + input1 = input_.clone().requires_grad_(True) + output1 = batchnorm1(input1) + output1.sum().backward() + optimizer1.step() + + input2 = input_.clone().requires_grad_(True) + output2 = batchnorm2(input2) + output2.sum().backward() + optimizer2.step() + + self.assertTensorClose(input1, input2) + self.assertTensorClose(output1, output2) + self.assertTensorClose(input1.grad, input2.grad) + self.assertTensorClose(batchnorm1.weight.grad, batchnorm2.weight.grad) + self.assertTensorClose(batchnorm1.bias.grad, batchnorm2.bias.grad) + self.assertTensorClose(batchnorm1.running_mean, batchnorm2.running_mean) + self.assertTensorClose(batchnorm2.running_mean, batchnorm2.running_mean) + + +if __name__ == '__main__': + unittest.main() + diff --git a/Global/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_sync_batchnorm.py b/Global/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_sync_batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..1f7b6c64c06fc26348489cd15669501a2098c82f --- /dev/null +++ b/Global/detection_models/Synchronized-BatchNorm-PyTorch/tests/test_sync_batchnorm.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +# File : test_sync_batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. + +import unittest + +import torch +import torch.nn as nn +from torch.autograd import Variable + +from sync_batchnorm import set_sbn_eps_mode +from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback +from sync_batchnorm.unittest import TorchTestCase + +set_sbn_eps_mode('plus') + + +def handy_var(a, unbias=True): + n = a.size(0) + asum = a.sum(dim=0) + as_sum = (a ** 2).sum(dim=0) # a square sum + sumvar = as_sum - asum * asum / n + if unbias: + return sumvar / (n - 1) + else: + return sumvar / n + + +def _find_bn(module): + for m in module.modules(): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): + return m + + +class SyncTestCase(TorchTestCase): + def _syncParameters(self, bn1, bn2): + bn1.reset_parameters() + bn2.reset_parameters() + if bn1.affine and bn2.affine: + bn2.weight.data.copy_(bn1.weight.data) + bn2.bias.data.copy_(bn1.bias.data) + + def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): + """Check the forward and backward for the customized batch normalization.""" + bn1.train(mode=is_train) + bn2.train(mode=is_train) + + if cuda: + input = input.cuda() + + self._syncParameters(_find_bn(bn1), _find_bn(bn2)) + + input1 = Variable(input, requires_grad=True) + output1 = bn1(input1) + output1.sum().backward() + input2 = Variable(input, requires_grad=True) + output2 = bn2(input2) + output2.sum().backward() + + self.assertTensorClose(input1.data, input2.data) + self.assertTensorClose(output1.data, output2.data) + self.assertTensorClose(input1.grad, input2.grad) + self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) + self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) + + def testSyncBatchNormNormalTrain(self): + bn = nn.BatchNorm1d(10) + sync_bn = SynchronizedBatchNorm1d(10) + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) + + def testSyncBatchNormNormalEval(self): + bn = nn.BatchNorm1d(10) + sync_bn = SynchronizedBatchNorm1d(10) + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) + + def testSyncBatchNormSyncTrain(self): + bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) + + def testSyncBatchNormSyncEval(self): + bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) + + def testSyncBatchNorm2DSyncTrain(self): + bn = nn.BatchNorm2d(10) + sync_bn = SynchronizedBatchNorm2d(10) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/Global/detection_models/__init__.py b/Global/detection_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Global/detection_models/antialiasing.py b/Global/detection_models/antialiasing.py new file mode 100644 index 0000000000000000000000000000000000000000..78da8ebdef518ffe597da1d03ffda09b89b22076 --- /dev/null +++ b/Global/detection_models/antialiasing.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.nn.parallel +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + + +class Downsample(nn.Module): + # https://github.com/adobe/antialiased-cnns + + def __init__(self, pad_type="reflect", filt_size=3, stride=2, channels=None, pad_off=0): + super(Downsample, self).__init__() + self.filt_size = filt_size + self.pad_off = pad_off + self.pad_sizes = [ + int(1.0 * (filt_size - 1) / 2), + int(np.ceil(1.0 * (filt_size - 1) / 2)), + int(1.0 * (filt_size - 1) / 2), + int(np.ceil(1.0 * (filt_size - 1) / 2)), + ] + self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes] + self.stride = stride + self.off = int((self.stride - 1) / 2.0) + self.channels = channels + + # print('Filter size [%i]'%filt_size) + if self.filt_size == 1: + a = np.array([1.0,]) + elif self.filt_size == 2: + a = np.array([1.0, 1.0]) + elif self.filt_size == 3: + a = np.array([1.0, 2.0, 1.0]) + elif self.filt_size == 4: + a = np.array([1.0, 3.0, 3.0, 1.0]) + elif self.filt_size == 5: + a = np.array([1.0, 4.0, 6.0, 4.0, 1.0]) + elif self.filt_size == 6: + a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0]) + elif self.filt_size == 7: + a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0]) + + filt = torch.Tensor(a[:, None] * a[None, :]) + filt = filt / torch.sum(filt) + self.register_buffer("filt", filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) + + self.pad = get_pad_layer(pad_type)(self.pad_sizes) + + def forward(self, inp): + if self.filt_size == 1: + if self.pad_off == 0: + return inp[:, :, :: self.stride, :: self.stride] + else: + return self.pad(inp)[:, :, :: self.stride, :: self.stride] + else: + return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) + + +def get_pad_layer(pad_type): + if pad_type in ["refl", "reflect"]: + PadLayer = nn.ReflectionPad2d + elif pad_type in ["repl", "replicate"]: + PadLayer = nn.ReplicationPad2d + elif pad_type == "zero": + PadLayer = nn.ZeroPad2d + else: + print("Pad type [%s] not recognized" % pad_type) + return PadLayer diff --git a/Global/detection_models/networks.py b/Global/detection_models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..d88bc5d5694db47220ccf70e97690de3224c2c60 --- /dev/null +++ b/Global/detection_models/networks.py @@ -0,0 +1,332 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from detection_models.sync_batchnorm import DataParallelWithCallback +from detection_models.antialiasing import Downsample + + +class UNet(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + depth=5, + conv_num=2, + wf=6, + padding=True, + batch_norm=True, + up_mode="upsample", + with_tanh=False, + sync_bn=True, + antialiasing=True, + ): + """ + Implementation of + U-Net: Convolutional Networks for Biomedical Image Segmentation + (Ronneberger et al., 2015) + https://arxiv.org/abs/1505.04597 + Using the default arguments will yield the exact version used + in the original paper + Args: + in_channels (int): number of input channels + out_channels (int): number of output channels + depth (int): depth of the network + wf (int): number of filters in the first layer is 2**wf + padding (bool): if True, apply padding such that the input shape + is the same as the output. + This may introduce artifacts + batch_norm (bool): Use BatchNorm after layers with an + activation function + up_mode (str): one of 'upconv' or 'upsample'. + 'upconv' will use transposed convolutions for + learned upsampling. + 'upsample' will use bilinear upsampling. + """ + super().__init__() + assert up_mode in ("upconv", "upsample") + self.padding = padding + self.depth = depth - 1 + prev_channels = in_channels + + self.first = nn.Sequential( + *[nn.ReflectionPad2d(3), nn.Conv2d(in_channels, 2 ** wf, kernel_size=7), nn.LeakyReLU(0.2, True)] + ) + prev_channels = 2 ** wf + + self.down_path = nn.ModuleList() + self.down_sample = nn.ModuleList() + for i in range(depth): + if antialiasing and depth > 0: + self.down_sample.append( + nn.Sequential( + *[ + nn.ReflectionPad2d(1), + nn.Conv2d(prev_channels, prev_channels, kernel_size=3, stride=1, padding=0), + nn.BatchNorm2d(prev_channels), + nn.LeakyReLU(0.2, True), + Downsample(channels=prev_channels, stride=2), + ] + ) + ) + else: + self.down_sample.append( + nn.Sequential( + *[ + nn.ReflectionPad2d(1), + nn.Conv2d(prev_channels, prev_channels, kernel_size=4, stride=2, padding=0), + nn.BatchNorm2d(prev_channels), + nn.LeakyReLU(0.2, True), + ] + ) + ) + self.down_path.append( + UNetConvBlock(conv_num, prev_channels, 2 ** (wf + i + 1), padding, batch_norm) + ) + prev_channels = 2 ** (wf + i + 1) + + self.up_path = nn.ModuleList() + for i in reversed(range(depth)): + self.up_path.append( + UNetUpBlock(conv_num, prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm) + ) + prev_channels = 2 ** (wf + i) + + if with_tanh: + self.last = nn.Sequential( + *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3), nn.Tanh()] + ) + else: + self.last = nn.Sequential( + *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3)] + ) + + if sync_bn: + self = DataParallelWithCallback(self) + + def forward(self, x): + x = self.first(x) + + blocks = [] + for i, down_block in enumerate(self.down_path): + blocks.append(x) + x = self.down_sample[i](x) + x = down_block(x) + + for i, up in enumerate(self.up_path): + x = up(x, blocks[-i - 1]) + + return self.last(x) + + +class UNetConvBlock(nn.Module): + def __init__(self, conv_num, in_size, out_size, padding, batch_norm): + super(UNetConvBlock, self).__init__() + block = [] + + for _ in range(conv_num): + block.append(nn.ReflectionPad2d(padding=int(padding))) + block.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=0)) + if batch_norm: + block.append(nn.BatchNorm2d(out_size)) + block.append(nn.LeakyReLU(0.2, True)) + in_size = out_size + + self.block = nn.Sequential(*block) + + def forward(self, x): + out = self.block(x) + return out + + +class UNetUpBlock(nn.Module): + def __init__(self, conv_num, in_size, out_size, up_mode, padding, batch_norm): + super(UNetUpBlock, self).__init__() + if up_mode == "upconv": + self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2) + elif up_mode == "upsample": + self.up = nn.Sequential( + nn.Upsample(mode="bilinear", scale_factor=2, align_corners=False), + nn.ReflectionPad2d(1), + nn.Conv2d(in_size, out_size, kernel_size=3, padding=0), + ) + + self.conv_block = UNetConvBlock(conv_num, in_size, out_size, padding, batch_norm) + + def center_crop(self, layer, target_size): + _, _, layer_height, layer_width = layer.size() + diff_y = (layer_height - target_size[0]) // 2 + diff_x = (layer_width - target_size[1]) // 2 + return layer[:, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])] + + def forward(self, x, bridge): + up = self.up(x) + crop1 = self.center_crop(bridge, up.shape[2:]) + out = torch.cat([up, crop1], 1) + out = self.conv_block(out) + + return out + + +class UnetGenerator(nn.Module): + """Create a Unet-based generator""" + + def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_type="BN", use_dropout=False): + """Construct a Unet generator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, + image of size 128x128 will become of size 1x1 # at the bottleneck + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + We construct the U-Net from the innermost layer to the outermost layer. + It is a recursive process. + """ + super().__init__() + if norm_type == "BN": + norm_layer = nn.BatchNorm2d + elif norm_type == "IN": + norm_layer = nn.InstanceNorm2d + else: + raise NameError("Unknown norm layer") + + # construct unet structure + unet_block = UnetSkipConnectionBlock( + ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True + ) # add the innermost layer + for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionBlock( + ngf * 8, + ngf * 8, + input_nc=None, + submodule=unet_block, + norm_layer=norm_layer, + use_dropout=use_dropout, + ) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock( + ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer + ) + unet_block = UnetSkipConnectionBlock( + ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer + ) + unet_block = UnetSkipConnectionBlock( + ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer + ) + self.model = UnetSkipConnectionBlock( + output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer + ) # add the outermost layer + + def forward(self, input): + return self.model(input) + + +class UnetSkipConnectionBlock(nn.Module): + """Defines the Unet submodule with skip connection. + + -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__( + self, + outer_nc, + inner_nc, + input_nc=None, + submodule=None, + outermost=False, + innermost=False, + norm_layer=nn.BatchNorm2d, + use_dropout=False, + ): + """Construct a Unet submodule with skip connections. + Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) -- previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + user_dropout (bool) -- if use dropout layers. + """ + super().__init__() + self.outermost = outermost + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.LeakyReLU(0.2, True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d( + inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias + ) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: # add skip connections + return torch.cat([x, self.model(x)], 1) + + +# ============================================ +# Network testing +# ============================================ +if __name__ == "__main__": + from torchsummary import summary + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + model = UNet_two_decoders( + in_channels=3, + out_channels1=3, + out_channels2=1, + depth=4, + conv_num=1, + wf=6, + padding=True, + batch_norm=True, + up_mode="upsample", + with_tanh=False, + ) + model.to(device) + + model_pix2pix = UnetGenerator(3, 3, 5, ngf=64, norm_type="BN", use_dropout=False) + model_pix2pix.to(device) + + print("customized unet:") + summary(model, (3, 256, 256)) + + print("cyclegan unet:") + summary(model_pix2pix, (3, 256, 256)) + + x = torch.zeros(1, 3, 256, 256).requires_grad_(True).cuda() + g = make_dot(model(x)) + g.render("models/Digraph.gv", view=False) + diff --git a/Global/detection_models/sync_batchnorm/__init__.py b/Global/detection_models/sync_batchnorm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9b36c74b1808b56ded68cf080a689db7e0ee4e --- /dev/null +++ b/Global/detection_models/sync_batchnorm/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import set_sbn_eps_mode +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .batchnorm import patch_sync_batchnorm, convert_model +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/Global/detection_models/sync_batchnorm/batchnorm.py b/Global/detection_models/sync_batchnorm/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8d7a7325b474771a11a137053971fd40426079 --- /dev/null +++ b/Global/detection_models/sync_batchnorm/batchnorm.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections +import contextlib + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm + +try: + from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast +except ImportError: + ReduceAddCoalesced = Broadcast = None + +try: + from jactorch.parallel.comm import SyncMaster + from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback +except ImportError: + from .comm import SyncMaster + from .replicate import DataParallelWithCallback + +__all__ = [ + 'set_sbn_eps_mode', + 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', + 'patch_sync_batchnorm', 'convert_model' +] + + +SBN_EPS_MODE = 'clamp' + + +def set_sbn_eps_mode(mode): + global SBN_EPS_MODE + assert mode in ('clamp', 'plus') + SBN_EPS_MODE = mode + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dimensions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): + assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' + + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, + track_running_stats=track_running_stats) + + if not self.track_running_stats: + import warnings + warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.') + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features) + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + if hasattr(torch, 'no_grad'): + with torch.no_grad(): + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + else: + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + if SBN_EPS_MODE == 'clamp': + return mean, bias_var.clamp(self.eps) ** -0.5 + elif SBN_EPS_MODE == 'plus': + return mean, (bias_var + self.eps) ** -0.5 + else: + raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE)) + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape:: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + + +@contextlib.contextmanager +def patch_sync_batchnorm(): + import torch.nn as nn + + backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d + + nn.BatchNorm1d = SynchronizedBatchNorm1d + nn.BatchNorm2d = SynchronizedBatchNorm2d + nn.BatchNorm3d = SynchronizedBatchNorm3d + + yield + + nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup + + +def convert_model(module): + """Traverse the input module and its child recursively + and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d + to SynchronizedBatchNorm*N*d + + Args: + module: the input module needs to be convert to SyncBN model + + Examples: + >>> import torch.nn as nn + >>> import torchvision + >>> # m is a standard pytorch model + >>> m = torchvision.models.resnet18(True) + >>> m = nn.DataParallel(m) + >>> # after convert, m is using SyncBN + >>> m = convert_model(m) + """ + if isinstance(module, torch.nn.DataParallel): + mod = module.module + mod = convert_model(mod) + mod = DataParallelWithCallback(mod, device_ids=module.device_ids) + return mod + + mod = module + for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, + torch.nn.modules.batchnorm.BatchNorm2d, + torch.nn.modules.batchnorm.BatchNorm3d], + [SynchronizedBatchNorm1d, + SynchronizedBatchNorm2d, + SynchronizedBatchNorm3d]): + if isinstance(module, pth_module): + mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) + mod.running_mean = module.running_mean + mod.running_var = module.running_var + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + + for name, child in module.named_children(): + mod.add_module(name, convert_model(child)) + + return mod diff --git a/Global/detection_models/sync_batchnorm/batchnorm_reimpl.py b/Global/detection_models/sync_batchnorm/batchnorm_reimpl.py new file mode 100644 index 0000000000000000000000000000000000000000..18145c3353e13d482c492ae46df91a537669fca0 --- /dev/null +++ b/Global/detection_models/sync_batchnorm/batchnorm_reimpl.py @@ -0,0 +1,74 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : batchnorm_reimpl.py +# Author : acgtyrant +# Date : 11/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import torch +import torch.nn as nn +import torch.nn.init as init + +__all__ = ['BatchNorm2dReimpl'] + + +class BatchNorm2dReimpl(nn.Module): + """ + A re-implementation of batch normalization, used for testing the numerical + stability. + + Author: acgtyrant + See also: + https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 + """ + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super().__init__() + + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.weight = nn.Parameter(torch.empty(num_features)) + self.bias = nn.Parameter(torch.empty(num_features)) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_running_stats(self): + self.running_mean.zero_() + self.running_var.fill_(1) + + def reset_parameters(self): + self.reset_running_stats() + init.uniform_(self.weight) + init.zeros_(self.bias) + + def forward(self, input_): + batchsize, channels, height, width = input_.size() + numel = batchsize * height * width + input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) + sum_ = input_.sum(1) + sum_of_square = input_.pow(2).sum(1) + mean = sum_ / numel + sumvar = sum_of_square - sum_ * mean + + self.running_mean = ( + (1 - self.momentum) * self.running_mean + + self.momentum * mean.detach() + ) + unbias_var = sumvar / (numel - 1) + self.running_var = ( + (1 - self.momentum) * self.running_var + + self.momentum * unbias_var.detach() + ) + + bias_var = sumvar / numel + inv_std = 1 / (bias_var + self.eps).pow(0.5) + output = ( + (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * + self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) + + return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() + diff --git a/Global/detection_models/sync_batchnorm/comm.py b/Global/detection_models/sync_batchnorm/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..922f8c4a3adaa9b32fdcaef09583be03b0d7eb2b --- /dev/null +++ b/Global/detection_models/sync_batchnorm/comm.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' + + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/Global/detection_models/sync_batchnorm/replicate.py b/Global/detection_models/sync_batchnorm/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..b71c7b8ed51a1d6c55b1f753bdd8d90bad79bd06 --- /dev/null +++ b/Global/detection_models/sync_batchnorm/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/Global/detection_models/sync_batchnorm/unittest.py b/Global/detection_models/sync_batchnorm/unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..998223a0e0242dc4a5b2fcd74af79dc7232794da --- /dev/null +++ b/Global/detection_models/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import unittest +import torch + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, x, y): + adiff = float((x - y).abs().max()) + if (y == 0).all(): + rdiff = 'NaN' + else: + rdiff = float((adiff / y).abs().max()) + + message = ( + 'Tensor close check failed\n' + 'adiff={}\n' + 'rdiff={}\n' + ).format(adiff, rdiff) + self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message) + diff --git a/Global/detection_util/util.py b/Global/detection_util/util.py new file mode 100644 index 0000000000000000000000000000000000000000..be10881fc4077015d12a28f5ae5b0a04021ad627 --- /dev/null +++ b/Global/detection_util/util.py @@ -0,0 +1,245 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +import sys +import time +import shutil +import platform +import numpy as np +from datetime import datetime + +import torch +import torchvision as tv +import torch.backends.cudnn as cudnn + +# from torch.utils.tensorboard import SummaryWriter + +import yaml +import matplotlib.pyplot as plt +from easydict import EasyDict as edict +import torchvision.utils as vutils + + +##### option parsing ###### +def print_options(config_dict): + print("------------ Options -------------") + for k, v in sorted(config_dict.items()): + print("%s: %s" % (str(k), str(v))) + print("-------------- End ----------------") + + +def save_options(config_dict): + from time import gmtime, strftime + + file_dir = os.path.join(config_dict["checkpoint_dir"], config_dict["name"]) + mkdir_if_not(file_dir) + file_name = os.path.join(file_dir, "opt.txt") + with open(file_name, "wt") as opt_file: + opt_file.write(os.path.basename(sys.argv[0]) + " " + strftime("%Y-%m-%d %H:%M:%S", gmtime()) + "\n") + opt_file.write("------------ Options -------------\n") + for k, v in sorted(config_dict.items()): + opt_file.write("%s: %s\n" % (str(k), str(v))) + opt_file.write("-------------- End ----------------\n") + + +def config_parse(config_file, options, save=True): + with open(config_file, "r") as stream: + config_dict = yaml.safe_load(stream) + config = edict(config_dict) + + for option_key, option_value in vars(options).items(): + config_dict[option_key] = option_value + config[option_key] = option_value + + if config.debug_mode: + config_dict["num_workers"] = 0 + config.num_workers = 0 + config.batch_size = 2 + if isinstance(config.gpu_ids, str): + config.gpu_ids = [int(x) for x in config.gpu_ids.split(",")][0] + + print_options(config_dict) + if save: + save_options(config_dict) + + return config + + +###### utility ###### +def to_np(x): + return x.cpu().numpy() + + +def prepare_device(use_gpu, gpu_ids): + if use_gpu: + cudnn.benchmark = True + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + if isinstance(gpu_ids, str): + gpu_ids = [int(x) for x in gpu_ids.split(",")] + torch.cuda.set_device(gpu_ids[0]) + device = torch.device("cuda:" + str(gpu_ids[0])) + else: + torch.cuda.set_device(gpu_ids) + device = torch.device("cuda:" + str(gpu_ids)) + print("running on GPU {}".format(gpu_ids)) + else: + device = torch.device("cpu") + print("running on CPU") + + return device + + +###### file system ###### +def get_dir_size(start_path="."): + total_size = 0 + for dirpath, dirnames, filenames in os.walk(start_path): + for f in filenames: + fp = os.path.join(dirpath, f) + total_size += os.path.getsize(fp) + return total_size + + +def mkdir_if_not(dir_path): + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + +##### System related ###### +class Timer: + def __init__(self, msg): + self.msg = msg + self.start_time = None + + def __enter__(self): + self.start_time = time.time() + + def __exit__(self, exc_type, exc_value, exc_tb): + elapse = time.time() - self.start_time + print(self.msg % elapse) + + +###### interactive ###### +def get_size(start_path="."): + total_size = 0 + for dirpath, dirnames, filenames in os.walk(start_path): + for f in filenames: + fp = os.path.join(dirpath, f) + total_size += os.path.getsize(fp) + return total_size + + +def clean_tensorboard(directory): + tensorboard_list = os.listdir(directory) + SIZE_THRESH = 100000 + for tensorboard in tensorboard_list: + tensorboard = os.path.join(directory, tensorboard) + if get_size(tensorboard) < SIZE_THRESH: + print("deleting the empty tensorboard: ", tensorboard) + # + if os.path.isdir(tensorboard): + shutil.rmtree(tensorboard) + else: + os.remove(tensorboard) + + +def prepare_tensorboard(config, experiment_name=datetime.now().strftime("%Y-%m-%d %H-%M-%S")): + tensorboard_directory = os.path.join(config.checkpoint_dir, config.name, "tensorboard_logs") + mkdir_if_not(tensorboard_directory) + clean_tensorboard(tensorboard_directory) + tb_writer = SummaryWriter(os.path.join(tensorboard_directory, experiment_name), flush_secs=10) + + # try: + # shutil.copy('outputs/opt.txt', tensorboard_directory) + # except: + # print('cannot find file opt.txt') + return tb_writer + + +def tb_loss_logger(tb_writer, iter_index, loss_logger): + for tag, value in loss_logger.items(): + tb_writer.add_scalar(tag, scalar_value=value.item(), global_step=iter_index) + + +def tb_image_logger(tb_writer, iter_index, images_info, config): + ### Save and write the output into the tensorboard + tb_logger_path = os.path.join(config.output_dir, config.name, config.train_mode) + mkdir_if_not(tb_logger_path) + for tag, image in images_info.items(): + if tag == "test_image_prediction" or tag == "image_prediction": + continue + image = tv.utils.make_grid(image.cpu()) + image = torch.clamp(image, 0, 1) + tb_writer.add_image(tag, img_tensor=image, global_step=iter_index) + tv.transforms.functional.to_pil_image(image).save( + os.path.join(tb_logger_path, "{:06d}_{}.jpg".format(iter_index, tag)) + ) + + +def tb_image_logger_test(epoch, iter, images_info, config): + + url = os.path.join(config.output_dir, config.name, config.train_mode, "val_" + str(epoch)) + if not os.path.exists(url): + os.makedirs(url) + scratch_img = images_info["test_scratch_image"].data.cpu() + if config.norm_input: + scratch_img = (scratch_img + 1.0) / 2.0 + scratch_img = torch.clamp(scratch_img, 0, 1) + gt_mask = images_info["test_mask_image"].data.cpu() + predict_mask = images_info["test_scratch_prediction"].data.cpu() + + predict_hard_mask = (predict_mask.data.cpu() >= 0.5).float() + + imgs = torch.cat((scratch_img, predict_hard_mask, gt_mask), 0) + img_grid = vutils.save_image( + imgs, os.path.join(url, str(iter) + ".jpg"), nrow=len(scratch_img), padding=0, normalize=True + ) + + +def imshow(input_image, title=None, to_numpy=False): + inp = input_image + if to_numpy or type(input_image) is torch.Tensor: + inp = input_image.numpy() + + fig = plt.figure() + if inp.ndim == 2: + fig = plt.imshow(inp, cmap="gray", clim=[0, 255]) + else: + fig = plt.imshow(np.transpose(inp, [1, 2, 0]).astype(np.uint8)) + plt.axis("off") + fig.axes.get_xaxis().set_visible(False) + fig.axes.get_yaxis().set_visible(False) + plt.title(title) + + +###### vgg preprocessing ###### +def vgg_preprocess(tensor): + # input is RGB tensor which ranges in [0,1] + # output is BGR tensor which ranges in [0,255] + tensor_bgr = torch.cat((tensor[:, 2:3, :, :], tensor[:, 1:2, :, :], tensor[:, 0:1, :, :]), dim=1) + # tensor_bgr = tensor[:, [2, 1, 0], ...] + tensor_bgr_ml = tensor_bgr - torch.Tensor([0.40760392, 0.45795686, 0.48501961]).type_as(tensor_bgr).view( + 1, 3, 1, 1 + ) + tensor_rst = tensor_bgr_ml * 255 + return tensor_rst + + +def torch_vgg_preprocess(tensor): + # pytorch version normalization + # note that both input and output are RGB tensors; + # input and output ranges in [0,1] + # normalize the tensor with mean and variance + tensor_mc = tensor - torch.Tensor([0.485, 0.456, 0.406]).type_as(tensor).view(1, 3, 1, 1) + tensor_mc_norm = tensor_mc / torch.Tensor([0.229, 0.224, 0.225]).type_as(tensor_mc).view(1, 3, 1, 1) + return tensor_mc_norm + + +def network_gradient(net, gradient_on=True): + if gradient_on: + for param in net.parameters(): + param.requires_grad = True + else: + for param in net.parameters(): + param.requires_grad = False + return net diff --git a/Global/models/NonLocal_feature_mapping_model.py b/Global/models/NonLocal_feature_mapping_model.py new file mode 100755 index 0000000000000000000000000000000000000000..1b9bb1031d8c1fe399fb4fa61e875027a6cfc4a5 --- /dev/null +++ b/Global/models/NonLocal_feature_mapping_model.py @@ -0,0 +1,202 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import os +import functools +from torch.autograd import Variable +from util.image_pool import ImagePool +from .base_model import BaseModel +from . import networks +import math + + +class Mapping_Model_with_mask(nn.Module): + def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None): + super(Mapping_Model_with_mask, self).__init__() + + norm_layer = networks.get_norm_layer(norm_type=norm) + activation = nn.ReLU(True) + model = [] + + tmp_nc = 64 + n_up = 4 + + for i in range(n_up): + ic = min(tmp_nc * (2 ** i), mc) + oc = min(tmp_nc * (2 ** (i + 1)), mc) + model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] + + self.before_NL = nn.Sequential(*model) + + if opt.NL_res: + self.NL = networks.NonLocalBlock2D_with_mask_Res( + mc, + mc, + opt.NL_fusion_method, + opt.correlation_renormalize, + opt.softmax_temperature, + opt.use_self, + opt.cosin_similarity, + ) + print("You are using NL + Res") + + model = [] + for i in range(n_blocks): + model += [ + networks.ResnetBlock( + mc, + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + dilation=opt.mapping_net_dilation, + ) + ] + + for i in range(n_up - 1): + ic = min(64 * (2 ** (4 - i)), mc) + oc = min(64 * (2 ** (3 - i)), mc) + model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] + model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)] + if opt.feat_dim > 0 and opt.feat_dim < 64: + model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)] + # model += [nn.Conv2d(64, 1, 1, 1, 0)] + self.after_NL = nn.Sequential(*model) + + + def forward(self, input, mask): + x1 = self.before_NL(input) + del input + x2 = self.NL(x1, mask) + del x1, mask + x3 = self.after_NL(x2) + del x2 + + return x3 + +class Mapping_Model_with_mask_2(nn.Module): ## Multi-Scale Patch Attention + def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None): + super(Mapping_Model_with_mask_2, self).__init__() + + norm_layer = networks.get_norm_layer(norm_type=norm) + activation = nn.ReLU(True) + model = [] + + tmp_nc = 64 + n_up = 4 + + for i in range(n_up): + ic = min(tmp_nc * (2 ** i), mc) + oc = min(tmp_nc * (2 ** (i + 1)), mc) + model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] + + for i in range(2): + model += [ + networks.ResnetBlock( + mc, + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + dilation=opt.mapping_net_dilation, + ) + ] + + print("Mapping: You are using multi-scale patch attention, conv combine + mask input") + + self.before_NL = nn.Sequential(*model) + + if opt.mapping_exp==1: + self.NL_scale_1=networks.Patch_Attention_4(mc,mc,8) + + model = [] + for i in range(2): + model += [ + networks.ResnetBlock( + mc, + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + dilation=opt.mapping_net_dilation, + ) + ] + + self.res_block_1 = nn.Sequential(*model) + + if opt.mapping_exp==1: + self.NL_scale_2=networks.Patch_Attention_4(mc,mc,4) + + model = [] + for i in range(2): + model += [ + networks.ResnetBlock( + mc, + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + dilation=opt.mapping_net_dilation, + ) + ] + + self.res_block_2 = nn.Sequential(*model) + + if opt.mapping_exp==1: + self.NL_scale_3=networks.Patch_Attention_4(mc,mc,2) + # self.NL_scale_3=networks.Patch_Attention_2(mc,mc,2) + + model = [] + for i in range(2): + model += [ + networks.ResnetBlock( + mc, + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + dilation=opt.mapping_net_dilation, + ) + ] + + for i in range(n_up - 1): + ic = min(64 * (2 ** (4 - i)), mc) + oc = min(64 * (2 ** (3 - i)), mc) + model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] + model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)] + if opt.feat_dim > 0 and opt.feat_dim < 64: + model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)] + # model += [nn.Conv2d(64, 1, 1, 1, 0)] + self.after_NL = nn.Sequential(*model) + + + def forward(self, input, mask): + x1 = self.before_NL(input) + x2 = self.NL_scale_1(x1,mask) + x3 = self.res_block_1(x2) + x4 = self.NL_scale_2(x3,mask) + x5 = self.res_block_2(x4) + x6 = self.NL_scale_3(x5,mask) + x7 = self.after_NL(x6) + return x7 + + def inference_forward(self, input, mask): + x1 = self.before_NL(input) + del input + x2 = self.NL_scale_1.inference_forward(x1,mask) + del x1 + x3 = self.res_block_1(x2) + del x2 + x4 = self.NL_scale_2.inference_forward(x3,mask) + del x3 + x5 = self.res_block_2(x4) + del x4 + x6 = self.NL_scale_3.inference_forward(x5,mask) + del x5 + x7 = self.after_NL(x6) + del x6 + return x7 \ No newline at end of file diff --git a/Global/models/__init__.py b/Global/models/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Global/models/base_model.py b/Global/models/base_model.py new file mode 100755 index 0000000000000000000000000000000000000000..4043116050e057f31099cda3ecae6ee3fa46cb2a --- /dev/null +++ b/Global/models/base_model.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +import torch +import sys + + +class BaseModel(torch.nn.Module): + def name(self): + return "BaseModel" + + def initialize(self, opt): + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) + + def set_input(self, input): + self.input = input + + def forward(self): + pass + + # used in test time, no backprop + def test(self): + pass + + def get_image_paths(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, network_label, epoch_label, gpu_ids): + save_filename = "%s_net_%s.pth" % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + torch.save(network.cpu().state_dict(), save_path) + if len(gpu_ids) and torch.cuda.is_available(): + network.cuda() + + def save_optimizer(self, optimizer, optimizer_label, epoch_label): + save_filename = "%s_optimizer_%s.pth" % (epoch_label, optimizer_label) + save_path = os.path.join(self.save_dir, save_filename) + torch.save(optimizer.state_dict(), save_path) + + def load_optimizer(self, optimizer, optimizer_label, epoch_label, save_dir=""): + save_filename = "%s_optimizer_%s.pth" % (epoch_label, optimizer_label) + if not save_dir: + save_dir = self.save_dir + save_path = os.path.join(save_dir, save_filename) + + if not os.path.isfile(save_path): + print("%s not exists yet!" % save_path) + else: + optimizer.load_state_dict(torch.load(save_path)) + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label, save_dir=""): + save_filename = "%s_net_%s.pth" % (epoch_label, network_label) + if not save_dir: + save_dir = self.save_dir + + # print(save_dir) + # print(self.save_dir) + save_path = os.path.join(save_dir, save_filename) + if not os.path.isfile(save_path): + print("%s not exists yet!" % save_path) + # if network_label == 'G': + # raise('Generator must exist!') + else: + # network.load_state_dict(torch.load(save_path)) + try: + # print(save_path) + network.load_state_dict(torch.load(save_path)) + except: + pretrained_dict = torch.load(save_path) + model_dict = network.state_dict() + try: + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + network.load_state_dict(pretrained_dict) + # if self.opt.verbose: + print( + "Pretrained network %s has excessive layers; Only loading layers that are used" + % network_label + ) + except: + print( + "Pretrained network %s has fewer layers; The following are not initialized:" + % network_label + ) + for k, v in pretrained_dict.items(): + if v.size() == model_dict[k].size(): + model_dict[k] = v + + if sys.version_info >= (3, 0): + not_initialized = set() + else: + from sets import Set + + not_initialized = Set() + + for k, v in model_dict.items(): + if k not in pretrained_dict or v.size() != pretrained_dict[k].size(): + not_initialized.add(k.split(".")[0]) + + print(sorted(not_initialized)) + network.load_state_dict(model_dict) + + def update_learning_rate(): + pass diff --git a/Global/models/mapping_model.py b/Global/models/mapping_model.py new file mode 100755 index 0000000000000000000000000000000000000000..e030f0f6274e9592494afbfaf17fa1d8371215ce --- /dev/null +++ b/Global/models/mapping_model.py @@ -0,0 +1,352 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import os +import functools +from torch.autograd import Variable +from util.image_pool import ImagePool +from .base_model import BaseModel +from . import networks +import math +from .NonLocal_feature_mapping_model import * + + +class Mapping_Model(nn.Module): + def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None): + super(Mapping_Model, self).__init__() + + norm_layer = networks.get_norm_layer(norm_type=norm) + activation = nn.ReLU(True) + model = [] + tmp_nc = 64 + n_up = 4 + + print("Mapping: You are using the mapping model without global restoration.") + + for i in range(n_up): + ic = min(tmp_nc * (2 ** i), mc) + oc = min(tmp_nc * (2 ** (i + 1)), mc) + model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] + for i in range(n_blocks): + model += [ + networks.ResnetBlock( + mc, + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + dilation=opt.mapping_net_dilation, + ) + ] + + for i in range(n_up - 1): + ic = min(64 * (2 ** (4 - i)), mc) + oc = min(64 * (2 ** (3 - i)), mc) + model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] + model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)] + if opt.feat_dim > 0 and opt.feat_dim < 64: + model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)] + # model += [nn.Conv2d(64, 1, 1, 1, 0)] + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + + +class Pix2PixHDModel_Mapping(BaseModel): + def name(self): + return "Pix2PixHDModel_Mapping" + + def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss, use_smooth_l1, stage_1_feat_l2): + flags = (True, True, use_gan_feat_loss, use_vgg_loss, True, True, use_smooth_l1, stage_1_feat_l2) + + def loss_filter(g_feat_l2, g_gan, g_gan_feat, g_vgg, d_real, d_fake, smooth_l1, stage_1_feat_l2): + return [ + l + for (l, f) in zip( + (g_feat_l2, g_gan, g_gan_feat, g_vgg, d_real, d_fake, smooth_l1, stage_1_feat_l2), flags + ) + if f + ] + + return loss_filter + + def initialize(self, opt): + BaseModel.initialize(self, opt) + if opt.resize_or_crop != "none" or not opt.isTrain: + torch.backends.cudnn.benchmark = True + self.isTrain = opt.isTrain + input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc + + ##### define networks + # Generator network + netG_input_nc = input_nc + self.netG_A = networks.GlobalGenerator_DCDCv2( + netG_input_nc, + opt.output_nc, + opt.ngf, + opt.k_size, + opt.n_downsample_global, + networks.get_norm_layer(norm_type=opt.norm), + opt=opt, + ) + self.netG_B = networks.GlobalGenerator_DCDCv2( + netG_input_nc, + opt.output_nc, + opt.ngf, + opt.k_size, + opt.n_downsample_global, + networks.get_norm_layer(norm_type=opt.norm), + opt=opt, + ) + + if opt.non_local == "Setting_42" or opt.NL_use_mask: + if opt.mapping_exp==1: + self.mapping_net = Mapping_Model_with_mask_2( + min(opt.ngf * 2 ** opt.n_downsample_global, opt.mc), + opt.map_mc, + n_blocks=opt.mapping_n_block, + opt=opt, + ) + else: + self.mapping_net = Mapping_Model_with_mask( + min(opt.ngf * 2 ** opt.n_downsample_global, opt.mc), + opt.map_mc, + n_blocks=opt.mapping_n_block, + opt=opt, + ) + else: + self.mapping_net = Mapping_Model( + min(opt.ngf * 2 ** opt.n_downsample_global, opt.mc), + opt.map_mc, + n_blocks=opt.mapping_n_block, + opt=opt, + ) + + self.mapping_net.apply(networks.weights_init) + + if opt.load_pretrain != "": + self.load_network(self.mapping_net, "mapping_net", opt.which_epoch, opt.load_pretrain) + + if not opt.no_load_VAE: + + self.load_network(self.netG_A, "G", opt.use_vae_which_epoch, opt.load_pretrainA) + self.load_network(self.netG_B, "G", opt.use_vae_which_epoch, opt.load_pretrainB) + for param in self.netG_A.parameters(): + param.requires_grad = False + for param in self.netG_B.parameters(): + param.requires_grad = False + self.netG_A.eval() + self.netG_B.eval() + + if opt.gpu_ids: + self.netG_A.cuda(opt.gpu_ids[0]) + self.netG_B.cuda(opt.gpu_ids[0]) + self.mapping_net.cuda(opt.gpu_ids[0]) + + if not self.isTrain: + self.load_network(self.mapping_net, "mapping_net", opt.which_epoch) + + # Discriminator network + if self.isTrain: + use_sigmoid = opt.no_lsgan + netD_input_nc = opt.ngf * 2 if opt.feat_gan else input_nc + opt.output_nc + if not opt.no_instance: + netD_input_nc += 1 + + self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid, + opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids) + + # set loss functions and optimizers + if self.isTrain: + if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: + raise NotImplementedError("Fake Pool Not Implemented for MultiGPU") + self.fake_pool = ImagePool(opt.pool_size) + self.old_lr = opt.lr + + # define loss functions + self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss, opt.Smooth_L1, opt.use_two_stage_mapping) + + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + + + self.criterionFeat = torch.nn.L1Loss() + self.criterionFeat_feat = torch.nn.L1Loss() if opt.use_l1_feat else torch.nn.MSELoss() + + if self.opt.image_L1: + self.criterionImage=torch.nn.L1Loss() + else: + self.criterionImage = torch.nn.SmoothL1Loss() + + + print(self.criterionFeat_feat) + if not opt.no_vgg_loss: + self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids) + + + # Names so we can breakout loss + self.loss_names = self.loss_filter('G_Feat_L2', 'G_GAN', 'G_GAN_Feat', 'G_VGG','D_real', 'D_fake', 'Smooth_L1', 'G_Feat_L2_Stage_1') + + # initialize optimizers + # optimizer G + + if opt.no_TTUR: + beta1,beta2=opt.beta1,0.999 + G_lr,D_lr=opt.lr,opt.lr + else: + beta1,beta2=0,0.9 + G_lr,D_lr=opt.lr/2,opt.lr*2 + + + if not opt.no_load_VAE: + params = list(self.mapping_net.parameters()) + self.optimizer_mapping = torch.optim.Adam(params, lr=G_lr, betas=(beta1, beta2)) + + # optimizer D + params = list(self.netD.parameters()) + self.optimizer_D = torch.optim.Adam(params, lr=D_lr, betas=(beta1, beta2)) + + print("---------- Optimizers initialized -------------") + + def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False): + if self.opt.label_nc == 0: + input_label = label_map.data.cuda() + else: + # create one-hot vector for label map + size = label_map.size() + oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) + input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() + input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) + if self.opt.data_type == 16: + input_label = input_label.half() + + # get edges from instance map + if not self.opt.no_instance: + inst_map = inst_map.data.cuda() + edge_map = self.get_edges(inst_map) + input_label = torch.cat((input_label, edge_map), dim=1) + input_label = Variable(input_label, volatile=infer) + + # real images for training + if real_image is not None: + real_image = Variable(real_image.data.cuda()) + + return input_label, inst_map, real_image, feat_map + + def discriminate(self, input_label, test_image, use_pool=False): + input_concat = torch.cat((input_label, test_image.detach()), dim=1) + if use_pool: + fake_query = self.fake_pool.query(input_concat) + return self.netD.forward(fake_query) + else: + return self.netD.forward(input_concat) + + def forward(self, label, inst, image, feat, pair=True, infer=False, last_label=None, last_image=None): + # Encode Inputs + input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat) + + # Fake Generation + input_concat = input_label + + label_feat = self.netG_A.forward(input_concat, flow='enc') + # print('label:') + # print(label_feat.min(), label_feat.max(), label_feat.mean()) + #label_feat = label_feat / 16.0 + + if self.opt.NL_use_mask: + label_feat_map=self.mapping_net(label_feat.detach(),inst) + else: + label_feat_map = self.mapping_net(label_feat.detach()) + + fake_image = self.netG_B.forward(label_feat_map, flow='dec') + image_feat = self.netG_B.forward(real_image, flow='enc') + + loss_feat_l2_stage_1=0 + loss_feat_l2 = self.criterionFeat_feat(label_feat_map, image_feat.data) * self.opt.l2_feat + + + if self.opt.feat_gan: + # Fake Detection and Loss + pred_fake_pool = self.discriminate(label_feat.detach(), label_feat_map, use_pool=True) + loss_D_fake = self.criterionGAN(pred_fake_pool, False) + + # Real Detection and Loss + pred_real = self.discriminate(label_feat.detach(), image_feat) + loss_D_real = self.criterionGAN(pred_real, True) + + # GAN loss (Fake Passability Loss) + pred_fake = self.netD.forward(torch.cat((label_feat.detach(), label_feat_map), dim=1)) + loss_G_GAN = self.criterionGAN(pred_fake, True) + else: + # Fake Detection and Loss + pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True) + loss_D_fake = self.criterionGAN(pred_fake_pool, False) + + # Real Detection and Loss + if pair: + pred_real = self.discriminate(input_label, real_image) + else: + pred_real = self.discriminate(last_label, last_image) + loss_D_real = self.criterionGAN(pred_real, True) + + # GAN loss (Fake Passability Loss) + pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1)) + loss_G_GAN = self.criterionGAN(pred_fake, True) + + # GAN feature matching loss + loss_G_GAN_Feat = 0 + if not self.opt.no_ganFeat_loss and pair: + feat_weights = 4.0 / (self.opt.n_layers_D + 1) + D_weights = 1.0 / self.opt.num_D + for i in range(self.opt.num_D): + for j in range(len(pred_fake[i])-1): + tmp = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat + loss_G_GAN_Feat += D_weights * feat_weights * tmp + else: + loss_G_GAN_Feat = torch.zeros(1).to(label.device) + + # VGG feature matching loss + loss_G_VGG = 0 + if not self.opt.no_vgg_loss: + loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat if pair else torch.zeros(1).to(label.device) + + smooth_l1_loss=0 + if self.opt.Smooth_L1: + smooth_l1_loss=self.criterionImage(fake_image,real_image)*self.opt.L1_weight + + + return [ self.loss_filter(loss_feat_l2, loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake,smooth_l1_loss,loss_feat_l2_stage_1), None if not infer else fake_image ] + + + def inference(self, label, inst): + + use_gpu = len(self.opt.gpu_ids) > 0 + if use_gpu: + input_concat = label.data.cuda() + inst_data = inst.cuda() + else: + input_concat = label.data + inst_data = inst + + label_feat = self.netG_A.forward(input_concat, flow="enc") + + if self.opt.NL_use_mask: + if self.opt.inference_optimize: + label_feat_map=self.mapping_net.inference_forward(label_feat.detach(),inst_data) + else: + label_feat_map = self.mapping_net(label_feat.detach(), inst_data) + else: + label_feat_map = self.mapping_net(label_feat.detach()) + + fake_image = self.netG_B.forward(label_feat_map, flow="dec") + return fake_image + + +class InferenceModel(Pix2PixHDModel_Mapping): + def forward(self, label, inst): + return self.inference(label, inst) + diff --git a/Global/models/models.py b/Global/models/models.py new file mode 100755 index 0000000000000000000000000000000000000000..fd7defde2b02567e1903870c6744c072580ee938 --- /dev/null +++ b/Global/models/models.py @@ -0,0 +1,47 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch + + +def create_model(opt): + if opt.model == "pix2pixHD": + from .pix2pixHD_model import Pix2PixHDModel, InferenceModel + + if opt.isTrain: + model = Pix2PixHDModel() + else: + model = InferenceModel() + else: + from .ui_model import UIModel + + model = UIModel() + model.initialize(opt) + if opt.verbose: + print("model [%s] was created" % (model.name())) + + if opt.isTrain and len(opt.gpu_ids) > 1: + # pass + model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) + + return model + +def create_da_model(opt): + if opt.model == 'pix2pixHD': + from .pix2pixHD_model_DA import Pix2PixHDModel, InferenceModel + if opt.isTrain: + model = Pix2PixHDModel() + else: + model = InferenceModel() + else: + from .ui_model import UIModel + model = UIModel() + model.initialize(opt) + if opt.verbose: + print("model [%s] was created" % (model.name())) + + if opt.isTrain and len(opt.gpu_ids) > 1: + #pass + model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) + + return model \ No newline at end of file diff --git a/Global/models/networks.py b/Global/models/networks.py new file mode 100755 index 0000000000000000000000000000000000000000..6c4b08664b7ea139b310b658a63d2e46e61d8d75 --- /dev/null +++ b/Global/models/networks.py @@ -0,0 +1,875 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.nn as nn +import functools +from torch.autograd import Variable +import numpy as np +from torch.nn.utils import spectral_norm + +# from util.util import SwitchNorm2d +import torch.nn.functional as F + +############################################################################### +# Functions +############################################################################### +def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(0.0, 0.02) + elif classname.find("BatchNorm2d") != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + + +def get_norm_layer(norm_type="instance"): + if norm_type == "batch": + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == "instance": + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + elif norm_type == "spectral": + norm_layer = spectral_norm() + elif norm_type == "SwitchNorm": + norm_layer = SwitchNorm2d + else: + raise NotImplementedError("normalization layer [%s] is not found" % norm_type) + return norm_layer + + +def print_network(net): + if isinstance(net, list): + net = net[0] + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print(net) + print("Total number of parameters: %d" % num_params) + + +def define_G(input_nc, output_nc, ngf, netG, k_size=3, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1, + n_blocks_local=3, norm='instance', gpu_ids=[], opt=None): + + norm_layer = get_norm_layer(norm_type=norm) + if netG == 'global': + # if opt.self_gen: + if opt.use_v2: + netG = GlobalGenerator_DCDCv2(input_nc, output_nc, ngf, k_size, n_downsample_global, norm_layer, opt=opt) + else: + netG = GlobalGenerator_v2(input_nc, output_nc, ngf, k_size, n_downsample_global, n_blocks_global, norm_layer, opt=opt) + else: + raise('generator not implemented!') + print(netG) + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + netG.cuda(gpu_ids[0]) + netG.apply(weights_init) + return netG + + +def define_D(input_nc, ndf, n_layers_D, opt, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]): + norm_layer = get_norm_layer(norm_type=norm) + netD = MultiscaleDiscriminator(input_nc, opt, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat) + print(netD) + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + netD.cuda(gpu_ids[0]) + netD.apply(weights_init) + return netD + + + +class GlobalGenerator_DCDCv2(nn.Module): + def __init__( + self, + input_nc, + output_nc, + ngf=64, + k_size=3, + n_downsampling=8, + norm_layer=nn.BatchNorm2d, + padding_type="reflect", + opt=None, + ): + super(GlobalGenerator_DCDCv2, self).__init__() + activation = nn.ReLU(True) + + model = [ + nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, min(ngf, opt.mc), kernel_size=7, padding=0), + norm_layer(ngf), + activation, + ] + ### downsample + for i in range(opt.start_r): + mult = 2 ** i + model += [ + nn.Conv2d( + min(ngf * mult, opt.mc), + min(ngf * mult * 2, opt.mc), + kernel_size=k_size, + stride=2, + padding=1, + ), + norm_layer(min(ngf * mult * 2, opt.mc)), + activation, + ] + for i in range(opt.start_r, n_downsampling - 1): + mult = 2 ** i + model += [ + nn.Conv2d( + min(ngf * mult, opt.mc), + min(ngf * mult * 2, opt.mc), + kernel_size=k_size, + stride=2, + padding=1, + ), + norm_layer(min(ngf * mult * 2, opt.mc)), + activation, + ] + model += [ + ResnetBlock( + min(ngf * mult * 2, opt.mc), + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + ) + ] + model += [ + ResnetBlock( + min(ngf * mult * 2, opt.mc), + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + ) + ] + mult = 2 ** (n_downsampling - 1) + + if opt.spatio_size == 32: + model += [ + nn.Conv2d( + min(ngf * mult, opt.mc), + min(ngf * mult * 2, opt.mc), + kernel_size=k_size, + stride=2, + padding=1, + ), + norm_layer(min(ngf * mult * 2, opt.mc)), + activation, + ] + if opt.spatio_size == 64: + model += [ + ResnetBlock( + min(ngf * mult * 2, opt.mc), + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + ) + ] + model += [ + ResnetBlock( + min(ngf * mult * 2, opt.mc), + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + ) + ] + # model += [nn.Conv2d(min(ngf * mult * 2, opt.mc), min(ngf, opt.mc), 1, 1)] + if opt.feat_dim > 0: + model += [nn.Conv2d(min(ngf * mult * 2, opt.mc), opt.feat_dim, 1, 1)] + self.encoder = nn.Sequential(*model) + + # decode + model = [] + if opt.feat_dim > 0: + model += [nn.Conv2d(opt.feat_dim, min(ngf * mult * 2, opt.mc), 1, 1)] + # model += [nn.Conv2d(min(ngf, opt.mc), min(ngf * mult * 2, opt.mc), 1, 1)] + o_pad = 0 if k_size == 4 else 1 + mult = 2 ** n_downsampling + model += [ + ResnetBlock( + min(ngf * mult, opt.mc), + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + ) + ] + + if opt.spatio_size == 32: + model += [ + nn.ConvTranspose2d( + min(ngf * mult, opt.mc), + min(int(ngf * mult / 2), opt.mc), + kernel_size=k_size, + stride=2, + padding=1, + output_padding=o_pad, + ), + norm_layer(min(int(ngf * mult / 2), opt.mc)), + activation, + ] + if opt.spatio_size == 64: + model += [ + ResnetBlock( + min(ngf * mult, opt.mc), + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + ) + ] + + for i in range(1, n_downsampling - opt.start_r): + mult = 2 ** (n_downsampling - i) + model += [ + ResnetBlock( + min(ngf * mult, opt.mc), + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + ) + ] + model += [ + ResnetBlock( + min(ngf * mult, opt.mc), + padding_type=padding_type, + activation=activation, + norm_layer=norm_layer, + opt=opt, + ) + ] + model += [ + nn.ConvTranspose2d( + min(ngf * mult, opt.mc), + min(int(ngf * mult / 2), opt.mc), + kernel_size=k_size, + stride=2, + padding=1, + output_padding=o_pad, + ), + norm_layer(min(int(ngf * mult / 2), opt.mc)), + activation, + ] + for i in range(n_downsampling - opt.start_r, n_downsampling): + mult = 2 ** (n_downsampling - i) + model += [ + nn.ConvTranspose2d( + min(ngf * mult, opt.mc), + min(int(ngf * mult / 2), opt.mc), + kernel_size=k_size, + stride=2, + padding=1, + output_padding=o_pad, + ), + norm_layer(min(int(ngf * mult / 2), opt.mc)), + activation, + ] + if opt.use_segmentation_model: + model += [nn.ReflectionPad2d(3), nn.Conv2d(min(ngf, opt.mc), output_nc, kernel_size=7, padding=0)] + else: + model += [ + nn.ReflectionPad2d(3), + nn.Conv2d(min(ngf, opt.mc), output_nc, kernel_size=7, padding=0), + nn.Tanh(), + ] + self.decoder = nn.Sequential(*model) + + def forward(self, input, flow="enc_dec"): + if flow == "enc": + return self.encoder(input) + elif flow == "dec": + return self.decoder(input) + elif flow == "enc_dec": + x = self.encoder(input) + x = self.decoder(x) + return x + + +# Define a resnet block +class ResnetBlock(nn.Module): + def __init__( + self, dim, padding_type, norm_layer, opt, activation=nn.ReLU(True), use_dropout=False, dilation=1 + ): + super(ResnetBlock, self).__init__() + self.opt = opt + self.dilation = dilation + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout) + + def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout): + conv_block = [] + p = 0 + if padding_type == "reflect": + conv_block += [nn.ReflectionPad2d(self.dilation)] + elif padding_type == "replicate": + conv_block += [nn.ReplicationPad2d(self.dilation)] + elif padding_type == "zero": + p = self.dilation + else: + raise NotImplementedError("padding [%s] is not implemented" % padding_type) + + conv_block += [ + nn.Conv2d(dim, dim, kernel_size=3, padding=p, dilation=self.dilation), + norm_layer(dim), + activation, + ] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == "reflect": + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == "replicate": + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == "zero": + p = 1 + else: + raise NotImplementedError("padding [%s] is not implemented" % padding_type) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, dilation=1), norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = x + self.conv_block(x) + return out + + +class Encoder(nn.Module): + def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d): + super(Encoder, self).__init__() + self.output_nc = output_nc + + model = [ + nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), + nn.ReLU(True), + ] + ### downsample + for i in range(n_downsampling): + mult = 2 ** i + model += [ + nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), + norm_layer(ngf * mult * 2), + nn.ReLU(True), + ] + + ### upsample + for i in range(n_downsampling): + mult = 2 ** (n_downsampling - i) + model += [ + nn.ConvTranspose2d( + ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1 + ), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True), + ] + + model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] + self.model = nn.Sequential(*model) + + def forward(self, input, inst): + outputs = self.model(input) + + # instance-wise average pooling + outputs_mean = outputs.clone() + inst_list = np.unique(inst.cpu().numpy().astype(int)) + for i in inst_list: + for b in range(input.size()[0]): + indices = (inst[b : b + 1] == int(i)).nonzero() # n x 4 + for j in range(self.output_nc): + output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]] + mean_feat = torch.mean(output_ins).expand_as(output_ins) + outputs_mean[ + indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3] + ] = mean_feat + return outputs_mean + + +def SN(module, mode=True): + if mode: + return torch.nn.utils.spectral_norm(module) + + return module + + +class NonLocalBlock2D_with_mask_Res(nn.Module): + def __init__( + self, + in_channels, + inter_channels, + mode="add", + re_norm=False, + temperature=1.0, + use_self=False, + cosin=False, + ): + super(NonLocalBlock2D_with_mask_Res, self).__init__() + + self.cosin = cosin + self.renorm = re_norm + self.in_channels = in_channels + self.inter_channels = inter_channels + + self.g = nn.Conv2d( + in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 + ) + + self.W = nn.Conv2d( + in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0 + ) + # for pytorch 0.3.1 + # nn.init.constant(self.W.weight, 0) + # nn.init.constant(self.W.bias, 0) + # for pytorch 0.4.0 + nn.init.constant_(self.W.weight, 0) + nn.init.constant_(self.W.bias, 0) + self.theta = nn.Conv2d( + in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 + ) + + self.phi = nn.Conv2d( + in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 + ) + + self.mode = mode + self.temperature = temperature + self.use_self = use_self + + norm_layer = get_norm_layer(norm_type="instance") + activation = nn.ReLU(True) + + model = [] + for i in range(3): + model += [ + ResnetBlock( + inter_channels, + padding_type="reflect", + activation=activation, + norm_layer=norm_layer, + opt=None, + ) + ] + self.res_block = nn.Sequential(*model) + + def forward(self, x, mask): ## The shape of mask is Batch*1*H*W + batch_size = x.size(0) + + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + + g_x = g_x.permute(0, 2, 1) + + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) + + theta_x = theta_x.permute(0, 2, 1) + + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + + if self.cosin: + theta_x = F.normalize(theta_x, dim=2) + phi_x = F.normalize(phi_x, dim=1) + + f = torch.matmul(theta_x, phi_x) + + f /= self.temperature + + f_div_C = F.softmax(f, dim=2) + + tmp = 1 - mask + mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear") + mask[mask > 0] = 1.0 + mask = 1 - mask + + tmp = F.interpolate(tmp, (x.size(2), x.size(3))) + mask *= tmp + + mask_expand = mask.view(batch_size, 1, -1) + mask_expand = mask_expand.repeat(1, x.size(2) * x.size(3), 1) + + # mask = 1 - mask + # mask=F.interpolate(mask,(x.size(2),x.size(3))) + # mask_expand=mask.view(batch_size,1,-1) + # mask_expand=mask_expand.repeat(1,x.size(2)*x.size(3),1) + + if self.use_self: + mask_expand[:, range(x.size(2) * x.size(3)), range(x.size(2) * x.size(3))] = 1.0 + + # print(mask_expand.shape) + # print(f_div_C.shape) + + f_div_C = mask_expand * f_div_C + if self.renorm: + f_div_C = F.normalize(f_div_C, p=1, dim=2) + + ########################### + + y = torch.matmul(f_div_C, g_x) + + y = y.permute(0, 2, 1).contiguous() + + y = y.view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + + W_y = self.res_block(W_y) + + if self.mode == "combine": + full_mask = mask.repeat(1, self.inter_channels, 1, 1) + z = full_mask * x + (1 - full_mask) * W_y + return z + + +class MultiscaleDiscriminator(nn.Module): + def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, + use_sigmoid=False, num_D=3, getIntermFeat=False): + super(MultiscaleDiscriminator, self).__init__() + self.num_D = num_D + self.n_layers = n_layers + self.getIntermFeat = getIntermFeat + + for i in range(num_D): + netD = NLayerDiscriminator(input_nc, opt, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat) + if getIntermFeat: + for j in range(n_layers+2): + setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j))) + else: + setattr(self, 'layer'+str(i), netD.model) + + self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) + + def singleD_forward(self, model, input): + if self.getIntermFeat: + result = [input] + for i in range(len(model)): + result.append(model[i](result[-1])) + return result[1:] + else: + return [model(input)] + + def forward(self, input): + num_D = self.num_D + result = [] + input_downsampled = input + for i in range(num_D): + if self.getIntermFeat: + model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)] + else: + model = getattr(self, 'layer'+str(num_D-1-i)) + result.append(self.singleD_forward(model, input_downsampled)) + if i != (num_D-1): + input_downsampled = self.downsample(input_downsampled) + return result + +# Defines the PatchGAN discriminator with the specified arguments. +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False): + super(NLayerDiscriminator, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[SN(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),opt.use_SN), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + SN(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),opt.use_SN), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + SN(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),opt.use_SN), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[SN(nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw),opt.use_SN)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[1:] + else: + return self.model(input) + + + +class Patch_Attention_4(nn.Module): ## While combine the feature map, use conv and mask + def __init__(self, in_channels, inter_channels, patch_size): + super(Patch_Attention_4, self).__init__() + + self.patch_size=patch_size + + + # self.g = nn.Conv2d( + # in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 + # ) + + # self.W = nn.Conv2d( + # in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0 + # ) + # # for pytorch 0.3.1 + # # nn.init.constant(self.W.weight, 0) + # # nn.init.constant(self.W.bias, 0) + # # for pytorch 0.4.0 + # nn.init.constant_(self.W.weight, 0) + # nn.init.constant_(self.W.bias, 0) + # self.theta = nn.Conv2d( + # in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 + # ) + + # self.phi = nn.Conv2d( + # in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 + # ) + + self.F_Combine=nn.Conv2d(in_channels=1025,out_channels=512,kernel_size=3,stride=1,padding=1,bias=True) + norm_layer = get_norm_layer(norm_type="instance") + activation = nn.ReLU(True) + + model = [] + for i in range(1): + model += [ + ResnetBlock( + inter_channels, + padding_type="reflect", + activation=activation, + norm_layer=norm_layer, + opt=None, + ) + ] + self.res_block = nn.Sequential(*model) + + def Hard_Compose(self, input, dim, index): + # batch index select + # input: [B,C,HW] + # dim: scalar > 0 + # index: [B, HW] + views = [input.size(0)] + [1 if i!=dim else -1 for i in range(1, len(input.size()))] + expanse = list(input.size()) + expanse[0] = -1 + expanse[dim] = -1 + index = index.view(views).expand(expanse) + return torch.gather(input, dim, index) + + def forward(self, z, mask): ## The shape of mask is Batch*1*H*W + + x=self.res_block(z) + + b,c,h,w=x.shape + + ## mask resize + dilation + # tmp = 1 - mask + mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear") + mask[mask > 0] = 1.0 + + # mask = 1 - mask + # tmp = F.interpolate(tmp, (x.size(2), x.size(3))) + # mask *= tmp + # mask=1-mask + ## 1: mask position 0: non-mask + + mask_unfold=F.unfold(mask, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size) + non_mask_region=(torch.mean(mask_unfold,dim=1,keepdim=True)>0.6).float() + all_patch_num=h*w/self.patch_size/self.patch_size + non_mask_region=non_mask_region.repeat(1,int(all_patch_num),1) + + x_unfold=F.unfold(x, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size) + y_unfold=x_unfold.permute(0,2,1) + x_unfold_normalized=F.normalize(x_unfold,dim=1) + y_unfold_normalized=F.normalize(y_unfold,dim=2) + correlation_matrix=torch.bmm(y_unfold_normalized,x_unfold_normalized) + correlation_matrix=correlation_matrix.masked_fill(non_mask_region==1.,-1e9) + correlation_matrix=F.softmax(correlation_matrix,dim=2) + + # print(correlation_matrix) + + R, max_arg=torch.max(correlation_matrix,dim=2) + + composed_unfold=self.Hard_Compose(x_unfold, 2, max_arg) + composed_fold=F.fold(composed_unfold,output_size=(h,w),kernel_size=(self.patch_size,self.patch_size),padding=0,stride=self.patch_size) + + concat_1=torch.cat((z,composed_fold,mask),dim=1) + concat_1=self.F_Combine(concat_1) + + return concat_1 + + def inference_forward(self,z,mask): ## Reduce the extra memory cost + + + x=self.res_block(z) + + b,c,h,w=x.shape + + ## mask resize + dilation + # tmp = 1 - mask + mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear") + mask[mask > 0] = 1.0 + # mask = 1 - mask + # tmp = F.interpolate(tmp, (x.size(2), x.size(3))) + # mask *= tmp + # mask=1-mask + ## 1: mask position 0: non-mask + + mask_unfold=F.unfold(mask, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size) + non_mask_region=(torch.mean(mask_unfold,dim=1,keepdim=True)>0.6).float()[0,0,:] # 1*1*all_patch_num + + all_patch_num=h*w/self.patch_size/self.patch_size + + mask_index=torch.nonzero(non_mask_region,as_tuple=True)[0] + + + if len(mask_index)==0: ## No mask patch is selected, no attention is needed + + composed_fold=x + + else: + + unmask_index=torch.nonzero(non_mask_region!=1,as_tuple=True)[0] + + x_unfold=F.unfold(x, kernel_size=(self.patch_size,self.patch_size), padding=0, stride=self.patch_size) + + Query_Patch=torch.index_select(x_unfold,2,mask_index) + Key_Patch=torch.index_select(x_unfold,2,unmask_index) + + Query_Patch=Query_Patch.permute(0,2,1) + Query_Patch_normalized=F.normalize(Query_Patch,dim=2) + Key_Patch_normalized=F.normalize(Key_Patch,dim=1) + + correlation_matrix=torch.bmm(Query_Patch_normalized,Key_Patch_normalized) + correlation_matrix=F.softmax(correlation_matrix,dim=2) + + + R, max_arg=torch.max(correlation_matrix,dim=2) + + composed_unfold=self.Hard_Compose(Key_Patch, 2, max_arg) + x_unfold[:,:,mask_index]=composed_unfold + composed_fold=F.fold(x_unfold,output_size=(h,w),kernel_size=(self.patch_size,self.patch_size),padding=0,stride=self.patch_size) + + concat_1=torch.cat((z,composed_fold,mask),dim=1) + concat_1=self.F_Combine(concat_1) + + + return concat_1 + +############################################################################## +# Losses +############################################################################## +class GANLoss(nn.Module): + def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, + tensor=torch.FloatTensor): + super(GANLoss, self).__init__() + self.real_label = target_real_label + self.fake_label = target_fake_label + self.real_label_var = None + self.fake_label_var = None + self.Tensor = tensor + if use_lsgan: + self.loss = nn.MSELoss() + else: + self.loss = nn.BCELoss() + + def get_target_tensor(self, input, target_is_real): + target_tensor = None + if target_is_real: + create_label = ((self.real_label_var is None) or + (self.real_label_var.numel() != input.numel())) + if create_label: + real_tensor = self.Tensor(input.size()).fill_(self.real_label) + self.real_label_var = Variable(real_tensor, requires_grad=False) + target_tensor = self.real_label_var + else: + create_label = ((self.fake_label_var is None) or + (self.fake_label_var.numel() != input.numel())) + if create_label: + fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) + self.fake_label_var = Variable(fake_tensor, requires_grad=False) + target_tensor = self.fake_label_var + return target_tensor + + def __call__(self, input, target_is_real): + if isinstance(input[0], list): + loss = 0 + for input_i in input: + pred = input_i[-1] + target_tensor = self.get_target_tensor(pred, target_is_real) + loss += self.loss(pred, target_tensor) + return loss + else: + target_tensor = self.get_target_tensor(input[-1], target_is_real) + return self.loss(input[-1], target_tensor) + + + + +####################################### VGG Loss + +from torchvision import models +class VGG19_torch(torch.nn.Module): + def __init__(self, requires_grad=False): + super(VGG19_torch, self).__init__() + vgg_pretrained_features = models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out + +class VGGLoss_torch(nn.Module): + def __init__(self, gpu_ids): + super(VGGLoss_torch, self).__init__() + self.vgg = VGG19_torch().cuda() + self.criterion = nn.L1Loss() + self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] + + def forward(self, x, y): + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + loss = 0 + for i in range(len(x_vgg)): + loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) + return loss \ No newline at end of file diff --git a/Global/models/pix2pixHD_model.py b/Global/models/pix2pixHD_model.py new file mode 100644 index 0000000000000000000000000000000000000000..edf829f7340d15007ab7563e88ace974cf8d08ee --- /dev/null +++ b/Global/models/pix2pixHD_model.py @@ -0,0 +1,333 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import numpy as np +import torch +import os +from torch.autograd import Variable +from util.image_pool import ImagePool +from .base_model import BaseModel +from . import networks + +class Pix2PixHDModel(BaseModel): + def name(self): + return 'Pix2PixHDModel' + + def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss,use_smooth_L1): + flags = (True, use_gan_feat_loss, use_vgg_loss, True, True, True,use_smooth_L1) + def loss_filter(g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake,smooth_l1): + return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg, g_kl, d_real,d_fake,smooth_l1),flags) if f] + return loss_filter + + def initialize(self, opt): + BaseModel.initialize(self, opt) + if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM + torch.backends.cudnn.benchmark = True + self.isTrain = opt.isTrain + self.use_features = opt.instance_feat or opt.label_feat ## Clearly it is false + self.gen_features = self.use_features and not self.opt.load_features ## it is also false + input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc ## Just is the origin input channel # + + ##### define networks + # Generator network + netG_input_nc = input_nc + if not opt.no_instance: + netG_input_nc += 1 + if self.use_features: + netG_input_nc += opt.feat_num + self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, opt.k_size, + opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, + opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids, opt=opt) + + # Discriminator network + if self.isTrain: + use_sigmoid = opt.no_lsgan + netD_input_nc = opt.output_nc if opt.no_cgan else input_nc + opt.output_nc + if not opt.no_instance: + netD_input_nc += 1 + self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid, + opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids) + + if self.opt.verbose: + print('---------- Networks initialized -------------') + + # load networks + if not self.isTrain or opt.continue_train or opt.load_pretrain: + pretrained_path = '' if not self.isTrain else opt.load_pretrain + self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) + + print("---------- G Networks reloaded -------------") + if self.isTrain: + self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path) + print("---------- D Networks reloaded -------------") + + + if self.gen_features: + self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path) + + # set loss functions and optimizers + if self.isTrain: + if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: ## The pool_size is 0! + raise NotImplementedError("Fake Pool Not Implemented for MultiGPU") + self.fake_pool = ImagePool(opt.pool_size) + self.old_lr = opt.lr + + # define loss functions + self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss, opt.Smooth_L1) + + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + self.criterionFeat = torch.nn.L1Loss() + + # self.criterionImage = torch.nn.SmoothL1Loss() + if not opt.no_vgg_loss: + self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids) + + + self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG', 'G_KL', 'D_real', 'D_fake', 'Smooth_L1') + + # initialize optimizers + # optimizer G + params = list(self.netG.parameters()) + if self.gen_features: + params += list(self.netE.parameters()) + self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) + + # optimizer D + params = list(self.netD.parameters()) + self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) + + print("---------- Optimizers initialized -------------") + + if opt.continue_train: + self.load_optimizer(self.optimizer_D, 'D', opt.which_epoch) + self.load_optimizer(self.optimizer_G, "G", opt.which_epoch) + for param_groups in self.optimizer_D.param_groups: + self.old_lr=param_groups['lr'] + + print("---------- Optimizers reloaded -------------") + print("---------- Current LR is %.8f -------------"%(self.old_lr)) + + ## We also want to re-load the parameters of optimizer. + + def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False): + if self.opt.label_nc == 0: + input_label = label_map.data.cuda() + else: + # create one-hot vector for label map + size = label_map.size() + oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) + input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() + input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) + if self.opt.data_type == 16: + input_label = input_label.half() + + # get edges from instance map + if not self.opt.no_instance: + inst_map = inst_map.data.cuda() + edge_map = self.get_edges(inst_map) + input_label = torch.cat((input_label, edge_map), dim=1) + input_label = Variable(input_label, volatile=infer) + + # real images for training + if real_image is not None: + real_image = Variable(real_image.data.cuda()) + + # instance map for feature encoding + if self.use_features: + # get precomputed feature maps + if self.opt.load_features: + feat_map = Variable(feat_map.data.cuda()) + if self.opt.label_feat: + inst_map = label_map.cuda() + + return input_label, inst_map, real_image, feat_map + + def discriminate(self, input_label, test_image, use_pool=False): + if input_label is None: + input_concat = test_image.detach() + else: + input_concat = torch.cat((input_label, test_image.detach()), dim=1) + if use_pool: + fake_query = self.fake_pool.query(input_concat) + return self.netD.forward(fake_query) + else: + return self.netD.forward(input_concat) + + def forward(self, label, inst, image, feat, infer=False): + # Encode Inputs + input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat) + + # Fake Generation + if self.use_features: + if not self.opt.load_features: + feat_map = self.netE.forward(real_image, inst_map) + input_concat = torch.cat((input_label, feat_map), dim=1) + else: + input_concat = input_label + hiddens = self.netG.forward(input_concat, 'enc') + noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device())) + # This is a reduced VAE implementation where we assume the outputs are multivariate Gaussian distribution with mean = hiddens and std_dev = all ones. + # We follow the the VAE of MUNIT (https://github.com/NVlabs/MUNIT/blob/master/networks.py) + fake_image = self.netG.forward(hiddens + noise, 'dec') + + if self.opt.no_cgan: + # Fake Detection and Loss + pred_fake_pool = self.discriminate(None, fake_image, use_pool=True) + loss_D_fake = self.criterionGAN(pred_fake_pool, False) + + # Real Detection and Loss + pred_real = self.discriminate(None, real_image) + loss_D_real = self.criterionGAN(pred_real, True) + + # GAN loss (Fake Passability Loss) + pred_fake = self.netD.forward(fake_image) + loss_G_GAN = self.criterionGAN(pred_fake, True) + else: + # Fake Detection and Loss + pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True) + loss_D_fake = self.criterionGAN(pred_fake_pool, False) + + # Real Detection and Loss + pred_real = self.discriminate(input_label, real_image) + loss_D_real = self.criterionGAN(pred_real, True) + + # GAN loss (Fake Passability Loss) + pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1)) + loss_G_GAN = self.criterionGAN(pred_fake, True) + + + loss_G_kl = torch.mean(torch.pow(hiddens, 2)) * self.opt.kl + + # GAN feature matching loss + loss_G_GAN_Feat = 0 + if not self.opt.no_ganFeat_loss: + feat_weights = 4.0 / (self.opt.n_layers_D + 1) + D_weights = 1.0 / self.opt.num_D + for i in range(self.opt.num_D): + for j in range(len(pred_fake[i])-1): + loss_G_GAN_Feat += D_weights * feat_weights * \ + self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat + + # VGG feature matching loss + loss_G_VGG = 0 + if not self.opt.no_vgg_loss: + loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat + + + smooth_l1_loss=0 + + return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_kl, loss_D_real, loss_D_fake,smooth_l1_loss ), None if not infer else fake_image ] + + def inference(self, label, inst, image=None, feat=None): + # Encode Inputs + image = Variable(image) if image is not None else None + input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True) + + # Fake Generation + if self.use_features: + if self.opt.use_encoded_image: + # encode the real image to get feature map + feat_map = self.netE.forward(real_image, inst_map) + else: + # sample clusters from precomputed features + feat_map = self.sample_features(inst_map) + input_concat = torch.cat((input_label, feat_map), dim=1) + else: + input_concat = input_label + + if torch.__version__.startswith('0.4'): + with torch.no_grad(): + fake_image = self.netG.forward(input_concat) + else: + fake_image = self.netG.forward(input_concat) + return fake_image + + def sample_features(self, inst): + # read precomputed feature clusters + cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path) + features_clustered = np.load(cluster_path, encoding='latin1').item() + + # randomly sample from the feature clusters + inst_np = inst.cpu().numpy().astype(int) + feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3]) + for i in np.unique(inst_np): + label = i if i < 1000 else i//1000 + if label in features_clustered: + feat = features_clustered[label] + cluster_idx = np.random.randint(0, feat.shape[0]) + + idx = (inst == int(i)).nonzero() + for k in range(self.opt.feat_num): + feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k] + if self.opt.data_type==16: + feat_map = feat_map.half() + return feat_map + + def encode_features(self, image, inst): + image = Variable(image.cuda(), volatile=True) + feat_num = self.opt.feat_num + h, w = inst.size()[2], inst.size()[3] + block_num = 32 + feat_map = self.netE.forward(image, inst.cuda()) + inst_np = inst.cpu().numpy().astype(int) + feature = {} + for i in range(self.opt.label_nc): + feature[i] = np.zeros((0, feat_num+1)) + for i in np.unique(inst_np): + label = i if i < 1000 else i//1000 + idx = (inst == int(i)).nonzero() + num = idx.size()[0] + idx = idx[num//2,:] + val = np.zeros((1, feat_num+1)) + for k in range(feat_num): + val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0] + val[0, feat_num] = float(num) / (h * w // block_num) + feature[label] = np.append(feature[label], val, axis=0) + return feature + + def get_edges(self, t): + edge = torch.cuda.ByteTensor(t.size()).zero_() + edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1]) + edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1]) + edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) + edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) + if self.opt.data_type==16: + return edge.half() + else: + return edge.float() + + def save(self, which_epoch): + self.save_network(self.netG, 'G', which_epoch, self.gpu_ids) + self.save_network(self.netD, 'D', which_epoch, self.gpu_ids) + + self.save_optimizer(self.optimizer_G,"G",which_epoch) + self.save_optimizer(self.optimizer_D,"D",which_epoch) + + if self.gen_features: + self.save_network(self.netE, 'E', which_epoch, self.gpu_ids) + + def update_fixed_params(self): + + params = list(self.netG.parameters()) + if self.gen_features: + params += list(self.netE.parameters()) + self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) + if self.opt.verbose: + print('------------ Now also finetuning global generator -----------') + + def update_learning_rate(self): + lrd = self.opt.lr / self.opt.niter_decay + lr = self.old_lr - lrd + for param_group in self.optimizer_D.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_G.param_groups: + param_group['lr'] = lr + if self.opt.verbose: + print('update learning rate: %f -> %f' % (self.old_lr, lr)) + self.old_lr = lr + + +class InferenceModel(Pix2PixHDModel): + def forward(self, inp): + label, inst = inp + return self.inference(label, inst) diff --git a/Global/models/pix2pixHD_model_DA.py b/Global/models/pix2pixHD_model_DA.py new file mode 100644 index 0000000000000000000000000000000000000000..617589df30ef1d808115332f76a77acaaeba099c --- /dev/null +++ b/Global/models/pix2pixHD_model_DA.py @@ -0,0 +1,372 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import numpy as np +import torch +import os +from torch.autograd import Variable +from util.image_pool import ImagePool +from .base_model import BaseModel +from . import networks + + +class Pix2PixHDModel(BaseModel): + def name(self): + return 'Pix2PixHDModel' + + def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss): + flags = (True, use_gan_feat_loss, use_vgg_loss, True, True, True, True, True, True) + + def loss_filter(g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake, g_featd, featd_real, featd_fake): + return [l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, g_kl, d_real, d_fake, g_featd, featd_real, featd_fake), flags) if f] + + return loss_filter + + def initialize(self, opt): + BaseModel.initialize(self, opt) + if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM + torch.backends.cudnn.benchmark = True + self.isTrain = opt.isTrain + self.use_features = opt.instance_feat or opt.label_feat ## Clearly it is false + self.gen_features = self.use_features and not self.opt.load_features ## it is also false + input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc ## Just is the origin input channel # + + ##### define networks + # Generator network + netG_input_nc = input_nc + if not opt.no_instance: + netG_input_nc += 1 + if self.use_features: + netG_input_nc += opt.feat_num + self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, opt.k_size, + opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, + opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids, opt=opt) + + # Discriminator network + if self.isTrain: + use_sigmoid = opt.no_lsgan + netD_input_nc = opt.output_nc if opt.no_cgan else input_nc + opt.output_nc + if not opt.no_instance: + netD_input_nc += 1 + self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt,opt.norm, use_sigmoid, + opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids) + + self.feat_D=networks.define_D(64, opt.ndf, opt.n_layers_D, opt, opt.norm, use_sigmoid, + 1, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids) + + if self.opt.verbose: + print('---------- Networks initialized -------------') + + # load networks + if not self.isTrain or opt.continue_train or opt.load_pretrain: + pretrained_path = '' if not self.isTrain else opt.load_pretrain + self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) + + print("---------- G Networks reloaded -------------") + if self.isTrain: + self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path) + self.load_network(self.feat_D, 'feat_D', opt.which_epoch, pretrained_path) + print("---------- D Networks reloaded -------------") + + + # set loss functions and optimizers + if self.isTrain: + if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: ## The pool_size is 0! + raise NotImplementedError("Fake Pool Not Implemented for MultiGPU") + self.fake_pool = ImagePool(opt.pool_size) + self.old_lr = opt.lr + + # define loss functions + self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss) + + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + self.criterionFeat = torch.nn.L1Loss() + if not opt.no_vgg_loss: + self.criterionVGG = networks.VGGLoss_torch(self.gpu_ids) + + # Names so we can breakout loss + self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'G_KL', 'D_real', 'D_fake', 'G_featD', 'featD_real','featD_fake') + + # initialize optimizers + # optimizer G + params = list(self.netG.parameters()) + if self.gen_features: + params += list(self.netE.parameters()) + self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) + + # optimizer D + params = list(self.netD.parameters()) + self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) + + params = list(self.feat_D.parameters()) + self.optimizer_featD = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) + + print("---------- Optimizers initialized -------------") + + if opt.continue_train: + self.load_optimizer(self.optimizer_D, 'D', opt.which_epoch) + self.load_optimizer(self.optimizer_G, "G", opt.which_epoch) + self.load_optimizer(self.optimizer_featD,'featD',opt.which_epoch) + for param_groups in self.optimizer_D.param_groups: + self.old_lr = param_groups['lr'] + + print("---------- Optimizers reloaded -------------") + print("---------- Current LR is %.8f -------------" % (self.old_lr)) + + ## We also want to re-load the parameters of optimizer. + + def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False): + if self.opt.label_nc == 0: + input_label = label_map.data.cuda() + else: + # create one-hot vector for label map + size = label_map.size() + oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) + input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() + input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) + if self.opt.data_type == 16: + input_label = input_label.half() + + # get edges from instance map + if not self.opt.no_instance: + inst_map = inst_map.data.cuda() + edge_map = self.get_edges(inst_map) + input_label = torch.cat((input_label, edge_map), dim=1) + input_label = Variable(input_label, volatile=infer) + + # real images for training + if real_image is not None: + real_image = Variable(real_image.data.cuda()) + + # instance map for feature encoding + if self.use_features: + # get precomputed feature maps + if self.opt.load_features: + feat_map = Variable(feat_map.data.cuda()) + if self.opt.label_feat: + inst_map = label_map.cuda() + + return input_label, inst_map, real_image, feat_map + + def discriminate(self, input_label, test_image, use_pool=False): + if input_label is None: + input_concat = test_image.detach() + else: + input_concat = torch.cat((input_label, test_image.detach()), dim=1) + if use_pool: + fake_query = self.fake_pool.query(input_concat) + return self.netD.forward(fake_query) + else: + return self.netD.forward(input_concat) + + def feat_discriminate(self,input): + + return self.feat_D.forward(input.detach()) + + + def forward(self, label, inst, image, feat, infer=False): + # Encode Inputs + input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat) + + # Fake Generation + if self.use_features: + if not self.opt.load_features: + feat_map = self.netE.forward(real_image, inst_map) + input_concat = torch.cat((input_label, feat_map), dim=1) + else: + input_concat = input_label + hiddens = self.netG.forward(input_concat, 'enc') + noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device())) + # This is a reduced VAE implementation where we assume the outputs are multivariate Gaussian distribution with mean = hiddens and std_dev = all ones. + # We follow the the VAE of MUNIT (https://github.com/NVlabs/MUNIT/blob/master/networks.py) + fake_image = self.netG.forward(hiddens + noise, 'dec') + + #################### + ##### GAN for the intermediate feature + real_old_feat =[] + syn_feat = [] + for index,x in enumerate(inst): + if x==1: + real_old_feat.append(hiddens[index].unsqueeze(0)) + else: + syn_feat.append(hiddens[index].unsqueeze(0)) + L=min(len(real_old_feat),len(syn_feat)) + real_old_feat=real_old_feat[:L] + syn_feat=syn_feat[:L] + real_old_feat=torch.cat(real_old_feat,0) + syn_feat=torch.cat(syn_feat,0) + + pred_fake_feat=self.feat_discriminate(real_old_feat) + loss_featD_fake = self.criterionGAN(pred_fake_feat, False) + pred_real_feat=self.feat_discriminate(syn_feat) + loss_featD_real = self.criterionGAN(pred_real_feat, True) + + pred_fake_feat_G=self.feat_D.forward(real_old_feat) + loss_G_featD=self.criterionGAN(pred_fake_feat_G,True) + + + ##################################### + if self.opt.no_cgan: + # Fake Detection and Loss + pred_fake_pool = self.discriminate(None, fake_image, use_pool=True) + loss_D_fake = self.criterionGAN(pred_fake_pool, False) + + # Real Detection and Loss + pred_real = self.discriminate(None, real_image) + loss_D_real = self.criterionGAN(pred_real, True) + + # GAN loss (Fake Passability Loss) + pred_fake = self.netD.forward(fake_image) + loss_G_GAN = self.criterionGAN(pred_fake, True) + else: + # Fake Detection and Loss + pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True) + loss_D_fake = self.criterionGAN(pred_fake_pool, False) + + # Real Detection and Loss + pred_real = self.discriminate(input_label, real_image) + loss_D_real = self.criterionGAN(pred_real, True) + + # GAN loss (Fake Passability Loss) + pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1)) + loss_G_GAN = self.criterionGAN(pred_fake, True) + + loss_G_kl = torch.mean(torch.pow(hiddens, 2)) * self.opt.kl + + # GAN feature matching loss + loss_G_GAN_Feat = 0 + if not self.opt.no_ganFeat_loss: + feat_weights = 4.0 / (self.opt.n_layers_D + 1) + D_weights = 1.0 / self.opt.num_D + for i in range(self.opt.num_D): + for j in range(len(pred_fake[i]) - 1): + loss_G_GAN_Feat += D_weights * feat_weights * \ + self.criterionFeat(pred_fake[i][j], + pred_real[i][j].detach()) * self.opt.lambda_feat + + # VGG feature matching loss + loss_G_VGG = 0 + if not self.opt.no_vgg_loss: + loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat + + # Only return the fake_B image if necessary to save BW + return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_kl, loss_D_real, loss_D_fake,loss_G_featD, loss_featD_real, loss_featD_fake), + None if not infer else fake_image] + + def inference(self, label, inst, image=None, feat=None): + # Encode Inputs + image = Variable(image) if image is not None else None + input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True) + + # Fake Generation + if self.use_features: + if self.opt.use_encoded_image: + # encode the real image to get feature map + feat_map = self.netE.forward(real_image, inst_map) + else: + # sample clusters from precomputed features + feat_map = self.sample_features(inst_map) + input_concat = torch.cat((input_label, feat_map), dim=1) + else: + input_concat = input_label + + if torch.__version__.startswith('0.4'): + with torch.no_grad(): + fake_image = self.netG.forward(input_concat) + else: + fake_image = self.netG.forward(input_concat) + return fake_image + + def sample_features(self, inst): + # read precomputed feature clusters + cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path) + features_clustered = np.load(cluster_path, encoding='latin1').item() + + # randomly sample from the feature clusters + inst_np = inst.cpu().numpy().astype(int) + feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3]) + for i in np.unique(inst_np): + label = i if i < 1000 else i // 1000 + if label in features_clustered: + feat = features_clustered[label] + cluster_idx = np.random.randint(0, feat.shape[0]) + + idx = (inst == int(i)).nonzero() + for k in range(self.opt.feat_num): + feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2], idx[:, 3]] = feat[cluster_idx, k] + if self.opt.data_type == 16: + feat_map = feat_map.half() + return feat_map + + def encode_features(self, image, inst): + image = Variable(image.cuda(), volatile=True) + feat_num = self.opt.feat_num + h, w = inst.size()[2], inst.size()[3] + block_num = 32 + feat_map = self.netE.forward(image, inst.cuda()) + inst_np = inst.cpu().numpy().astype(int) + feature = {} + for i in range(self.opt.label_nc): + feature[i] = np.zeros((0, feat_num + 1)) + for i in np.unique(inst_np): + label = i if i < 1000 else i // 1000 + idx = (inst == int(i)).nonzero() + num = idx.size()[0] + idx = idx[num // 2, :] + val = np.zeros((1, feat_num + 1)) + for k in range(feat_num): + val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0] + val[0, feat_num] = float(num) / (h * w // block_num) + feature[label] = np.append(feature[label], val, axis=0) + return feature + + def get_edges(self, t): + edge = torch.cuda.ByteTensor(t.size()).zero_() + edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1]) + edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1]) + edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) + edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) + if self.opt.data_type == 16: + return edge.half() + else: + return edge.float() + + def save(self, which_epoch): + self.save_network(self.netG, 'G', which_epoch, self.gpu_ids) + self.save_network(self.netD, 'D', which_epoch, self.gpu_ids) + self.save_network(self.feat_D,'featD',which_epoch,self.gpu_ids) + + self.save_optimizer(self.optimizer_G, "G", which_epoch) + self.save_optimizer(self.optimizer_D, "D", which_epoch) + self.save_optimizer(self.optimizer_featD,'featD',which_epoch) + + if self.gen_features: + self.save_network(self.netE, 'E', which_epoch, self.gpu_ids) + + def update_fixed_params(self): + + params = list(self.netG.parameters()) + if self.gen_features: + params += list(self.netE.parameters()) + self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) + if self.opt.verbose: + print('------------ Now also finetuning global generator -----------') + + def update_learning_rate(self): + lrd = self.opt.lr / self.opt.niter_decay + lr = self.old_lr - lrd + for param_group in self.optimizer_D.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_G.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_featD.param_groups: + param_group['lr'] = lr + if self.opt.verbose: + print('update learning rate: %f -> %f' % (self.old_lr, lr)) + self.old_lr = lr + + +class InferenceModel(Pix2PixHDModel): + def forward(self, inp): + label, inst = inp + return self.inference(label, inst) diff --git a/Global/options/__init__.py b/Global/options/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Global/options/base_options.py b/Global/options/base_options.py new file mode 100755 index 0000000000000000000000000000000000000000..b8ef551eb982a3b551f77090028304f40883a94a --- /dev/null +++ b/Global/options/base_options.py @@ -0,0 +1,373 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +import os +from util import util +import torch + + +class BaseOptions: + def __init__(self): + self.parser = argparse.ArgumentParser() + self.initialized = False + + def initialize(self): + # experiment specifics + self.parser.add_argument( + "--name", + type=str, + default="label2city", + help="name of the experiment. It decides where to store samples and models", + ) + self.parser.add_argument( + "--gpu_ids", type=str, default="0", help="gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU" + ) + self.parser.add_argument( + "--checkpoints_dir", type=str, default="./checkpoints", help="models are saved here" + ) ## note: to add this param when using philly + # self.parser.add_argument('--project_dir', type=str, default='./', help='the project is saved here') ################### This is necessary for philly + self.parser.add_argument( + "--outputs_dir", type=str, default="./outputs", help="models are saved here" + ) ## note: to add this param when using philly Please end with '/' + self.parser.add_argument("--model", type=str, default="pix2pixHD", help="which model to use") + self.parser.add_argument( + "--norm", type=str, default="instance", help="instance normalization or batch normalization" + ) + self.parser.add_argument("--use_dropout", action="store_true", help="use dropout for the generator") + self.parser.add_argument( + "--data_type", + default=32, + type=int, + choices=[8, 16, 32], + help="Supported data type i.e. 8, 16, 32 bit", + ) + self.parser.add_argument("--verbose", action="store_true", default=False, help="toggles verbose") + + # input/output sizes + self.parser.add_argument("--batchSize", type=int, default=1, help="input batch size") + self.parser.add_argument("--loadSize", type=int, default=1024, help="scale images to this size") + self.parser.add_argument("--fineSize", type=int, default=512, help="then crop to this size") + self.parser.add_argument("--label_nc", type=int, default=35, help="# of input label channels") + self.parser.add_argument("--input_nc", type=int, default=3, help="# of input image channels") + self.parser.add_argument("--output_nc", type=int, default=3, help="# of output image channels") + + # for setting inputs + self.parser.add_argument("--dataroot", type=str, default="./datasets/cityscapes/") + self.parser.add_argument( + "--resize_or_crop", + type=str, + default="scale_width", + help="scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]", + ) + self.parser.add_argument( + "--serial_batches", + action="store_true", + help="if true, takes images in order to make batches, otherwise takes them randomly", + ) + self.parser.add_argument( + "--no_flip", + action="store_true", + help="if specified, do not flip the images for data argumentation", + ) + self.parser.add_argument("--nThreads", default=2, type=int, help="# threads for loading data") + self.parser.add_argument( + "--max_dataset_size", + type=int, + default=float("inf"), + help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.", + ) + + # for displays + self.parser.add_argument("--display_winsize", type=int, default=512, help="display window size") + self.parser.add_argument( + "--tf_log", + action="store_true", + help="if specified, use tensorboard logging. Requires tensorflow installed", + ) + + # for generator + self.parser.add_argument("--netG", type=str, default="global", help="selects model to use for netG") + self.parser.add_argument("--ngf", type=int, default=64, help="# of gen filters in first conv layer") + self.parser.add_argument("--k_size", type=int, default=3, help="# kernel size conv layer") + self.parser.add_argument("--use_v2", action="store_true", help="use DCDCv2") + self.parser.add_argument("--mc", type=int, default=1024, help="# max channel") + self.parser.add_argument("--start_r", type=int, default=3, help="start layer to use resblock") + self.parser.add_argument( + "--n_downsample_global", type=int, default=4, help="number of downsampling layers in netG" + ) + self.parser.add_argument( + "--n_blocks_global", + type=int, + default=9, + help="number of residual blocks in the global generator network", + ) + self.parser.add_argument( + "--n_blocks_local", + type=int, + default=3, + help="number of residual blocks in the local enhancer network", + ) + self.parser.add_argument( + "--n_local_enhancers", type=int, default=1, help="number of local enhancers to use" + ) + self.parser.add_argument( + "--niter_fix_global", + type=int, + default=0, + help="number of epochs that we only train the outmost local enhancer", + ) + + self.parser.add_argument( + "--load_pretrain", + type=str, + default="", + help="load the pretrained model from the specified location", + ) + + # for instance-wise features + self.parser.add_argument( + "--no_instance", action="store_true", help="if specified, do *not* add instance map as input" + ) + self.parser.add_argument( + "--instance_feat", + action="store_true", + help="if specified, add encoded instance features as input", + ) + self.parser.add_argument( + "--label_feat", action="store_true", help="if specified, add encoded label features as input" + ) + self.parser.add_argument("--feat_num", type=int, default=3, help="vector length for encoded features") + self.parser.add_argument( + "--load_features", action="store_true", help="if specified, load precomputed feature maps" + ) + self.parser.add_argument( + "--n_downsample_E", type=int, default=4, help="# of downsampling layers in encoder" + ) + self.parser.add_argument( + "--nef", type=int, default=16, help="# of encoder filters in the first conv layer" + ) + self.parser.add_argument("--n_clusters", type=int, default=10, help="number of clusters for features") + + # diy + self.parser.add_argument("--self_gen", action="store_true", help="self generate") + self.parser.add_argument( + "--mapping_n_block", type=int, default=3, help="number of resblock in mapping" + ) + self.parser.add_argument("--map_mc", type=int, default=64, help="max channel of mapping") + self.parser.add_argument("--kl", type=float, default=0, help="KL Loss") + self.parser.add_argument( + "--load_pretrainA", + type=str, + default="", + help="load the pretrained model from the specified location", + ) + self.parser.add_argument( + "--load_pretrainB", + type=str, + default="", + help="load the pretrained model from the specified location", + ) + self.parser.add_argument("--feat_gan", action="store_true") + self.parser.add_argument("--no_cgan", action="store_true") + self.parser.add_argument("--map_unet", action="store_true") + self.parser.add_argument("--map_densenet", action="store_true") + self.parser.add_argument("--fcn", action="store_true") + self.parser.add_argument("--is_image", action="store_true", help="train image recon only pair data") + self.parser.add_argument("--label_unpair", action="store_true") + self.parser.add_argument("--mapping_unpair", action="store_true") + self.parser.add_argument("--unpair_w", type=float, default=1.0) + self.parser.add_argument("--pair_num", type=int, default=-1) + self.parser.add_argument("--Gan_w", type=float, default=1) + self.parser.add_argument("--feat_dim", type=int, default=-1) + self.parser.add_argument("--abalation_vae_len", type=int, default=-1) + + ######################### useless, just to cooperate with docker + self.parser.add_argument("--gpu", type=str) + self.parser.add_argument("--dataDir", type=str) + self.parser.add_argument("--modelDir", type=str) + self.parser.add_argument("--logDir", type=str) + self.parser.add_argument("--data_dir", type=str) + + self.parser.add_argument("--use_skip_model", action="store_true") + self.parser.add_argument("--use_segmentation_model", action="store_true") + + self.parser.add_argument("--spatio_size", type=int, default=64) + self.parser.add_argument("--test_random_crop", action="store_true") + ########################## + + self.parser.add_argument("--contain_scratch_L", action="store_true") + self.parser.add_argument( + "--mask_dilation", type=int, default=0 + ) ## Don't change the input, only dilation the mask + + self.parser.add_argument( + "--irregular_mask", type=str, default="", help="This is the root of the mask" + ) + self.parser.add_argument( + "--mapping_net_dilation", + type=int, + default=1, + help="This parameter is the dilation size of the translation net", + ) + + self.parser.add_argument( + "--VOC", type=str, default="VOC_RGB_JPEGImages.bigfile", help="The root of VOC dataset" + ) + + self.parser.add_argument("--non_local", type=str, default="", help="which non_local setting") + self.parser.add_argument( + "--NL_fusion_method", + type=str, + default="add", + help="how to fuse the origin feature and nl feature", + ) + self.parser.add_argument( + "--NL_use_mask", action="store_true", help="If use mask while using Non-local mapping model" + ) + self.parser.add_argument( + "--correlation_renormalize", + action="store_true", + help="Since after mask out the correlation matrix(which is softmaxed), the sum is not 1 any more, enable this param to re-weight", + ) + + self.parser.add_argument("--Smooth_L1", action="store_true", help="Use L1 Loss in image level") + + self.parser.add_argument( + "--face_restore_setting", type=int, default=1, help="This is for the aligned face restoration" + ) + self.parser.add_argument("--face_clean_url", type=str, default="") + self.parser.add_argument("--syn_input_url", type=str, default="") + self.parser.add_argument("--syn_gt_url", type=str, default="") + + self.parser.add_argument( + "--test_on_synthetic", + action="store_true", + help="If you want to test on the synthetic data, enable this parameter", + ) + + self.parser.add_argument("--use_SN", action="store_true", help="Add SN to every parametric layer") + + self.parser.add_argument( + "--use_two_stage_mapping", action="store_true", help="choose the model which uses two stage" + ) + + self.parser.add_argument("--L1_weight", type=float, default=10.0) + self.parser.add_argument("--softmax_temperature", type=float, default=1.0) + self.parser.add_argument( + "--patch_similarity", + action="store_true", + help="Enable this denotes using 3*3 patch to calculate similarity", + ) + self.parser.add_argument( + "--use_self", + action="store_true", + help="Enable this denotes that while constructing the new feature maps, using original feature (diagonal == 1)", + ) + + self.parser.add_argument("--use_own_dataset", action="store_true") + + self.parser.add_argument( + "--test_hole_two_folders", + action="store_true", + help="Enable this parameter means test the restoration with inpainting given twp folders which are mask and old respectively", + ) + + self.parser.add_argument( + "--no_hole", + action="store_true", + help="While test the full_model on non_scratch data, do not add random mask into the real old photos", + ) ## Only for testing + self.parser.add_argument( + "--random_hole", + action="store_true", + help="While training the full model, 50% probability add hole", + ) + + self.parser.add_argument("--NL_res", action="store_true", help="NL+Resdual Block") + + self.parser.add_argument("--image_L1", action="store_true", help="Image level loss: L1") + self.parser.add_argument( + "--hole_image_no_mask", + action="store_true", + help="while testing, give hole image but not give the mask", + ) + + self.parser.add_argument( + "--down_sample_degradation", + action="store_true", + help="down_sample the image only, corresponds to [down_sample_face]", + ) + + self.parser.add_argument( + "--norm_G", type=str, default="spectralinstance", help="The norm type of Generator" + ) + self.parser.add_argument( + "--init_G", + type=str, + default="xavier", + help="normal|xavier|xavier_uniform|kaiming|orthogonal|none", + ) + + self.parser.add_argument("--use_new_G", action="store_true") + self.parser.add_argument("--use_new_D", action="store_true") + + self.parser.add_argument( + "--only_voc", action="store_true", help="test the trianed celebA face model using VOC face" + ) + + self.parser.add_argument( + "--cosin_similarity", + action="store_true", + help="For non-local, using cosin to calculate the similarity", + ) + + self.parser.add_argument( + "--downsample_mode", + type=str, + default="nearest", + help="For partial non-local, choose how to downsample the mask", + ) + + self.parser.add_argument("--mapping_exp",type=int,default=0,help='Default 0: original PNL|1: Multi-Scale Patch Attention') + self.parser.add_argument("--inference_optimize",action='store_true',help='optimize the memory cost') + + + self.initialized = True + + def parse(self, save=True): + if not self.initialized: + self.initialize() + self.opt = self.parser.parse_args() + self.opt.isTrain = self.isTrain # train or test + + str_ids = self.opt.gpu_ids.split(",") + self.opt.gpu_ids = [] + for str_id in str_ids: + int_id = int(str_id) + if int_id >= 0: + self.opt.gpu_ids.append(int_id) + + # set gpu ids + if len(self.opt.gpu_ids) > 0: + # pass + torch.cuda.set_device(self.opt.gpu_ids[0]) + + args = vars(self.opt) + + # print('------------ Options -------------') + # for k, v in sorted(args.items()): + # print('%s: %s' % (str(k), str(v))) + # print('-------------- End ----------------') + + # save to the disk + expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) + util.mkdirs(expr_dir) + if save and not self.opt.continue_train: + file_name = os.path.join(expr_dir, "opt.txt") + with open(file_name, "wt") as opt_file: + opt_file.write("------------ Options -------------\n") + for k, v in sorted(args.items()): + opt_file.write("%s: %s\n" % (str(k), str(v))) + opt_file.write("-------------- End ----------------\n") + return self.opt diff --git a/Global/options/test_options.py b/Global/options/test_options.py new file mode 100755 index 0000000000000000000000000000000000000000..67e2e3a720cf7f9e540b09b64242197cdb712b57 --- /dev/null +++ b/Global/options/test_options.py @@ -0,0 +1,100 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + def initialize(self): + BaseOptions.initialize(self) + self.parser.add_argument("--ntest", type=int, default=float("inf"), help="# of test examples.") + self.parser.add_argument("--results_dir", type=str, default="./results/", help="saves results here.") + self.parser.add_argument( + "--aspect_ratio", type=float, default=1.0, help="aspect ratio of result images" + ) + self.parser.add_argument("--phase", type=str, default="test", help="train, val, test, etc") + self.parser.add_argument( + "--which_epoch", + type=str, + default="latest", + help="which epoch to load? set to latest to use latest cached model", + ) + self.parser.add_argument("--how_many", type=int, default=50, help="how many test images to run") + self.parser.add_argument( + "--cluster_path", + type=str, + default="features_clustered_010.npy", + help="the path for clustered results of encoded features", + ) + self.parser.add_argument( + "--use_encoded_image", + action="store_true", + help="if specified, encode the real image to get the feature map", + ) + self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file") + self.parser.add_argument("--engine", type=str, help="run serialized TRT engine") + self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT") + self.parser.add_argument( + "--start_epoch", + type=int, + default=-1, + help="write the start_epoch of iter.txt into this parameter", + ) + + self.parser.add_argument("--test_dataset", type=str, default="Real_RGB_old.bigfile") + self.parser.add_argument( + "--no_degradation", + action="store_true", + help="when train the mapping, enable this parameter --> no degradation will be added into clean image", + ) + self.parser.add_argument( + "--no_load_VAE", + action="store_true", + help="when train the mapping, enable this parameter --> random initialize the encoder an decoder", + ) + self.parser.add_argument( + "--use_v2_degradation", + action="store_true", + help="enable this parameter --> 4 kinds of degradations will be used to synthesize corruption", + ) + self.parser.add_argument("--use_vae_which_epoch", type=str, default="latest") + self.isTrain = False + + self.parser.add_argument("--generate_pair", action="store_true") + + self.parser.add_argument("--multi_scale_test", type=float, default=0.5) + self.parser.add_argument("--multi_scale_threshold", type=float, default=0.5) + self.parser.add_argument( + "--mask_need_scale", + action="store_true", + help="enable this param meas that the pixel range of mask is 0-255", + ) + self.parser.add_argument("--scale_num", type=int, default=1) + + self.parser.add_argument( + "--save_feature_url", type=str, default="", help="While extracting the features, where to put" + ) + + self.parser.add_argument( + "--test_input", type=str, default="", help="A directory or a root of bigfile" + ) + self.parser.add_argument("--test_mask", type=str, default="", help="A directory or a root of bigfile") + self.parser.add_argument("--test_gt", type=str, default="", help="A directory or a root of bigfile") + + self.parser.add_argument( + "--scale_input", action="store_true", help="While testing, choose to scale the input firstly" + ) + + self.parser.add_argument( + "--save_feature_name", type=str, default="features.json", help="The name of saved features" + ) + self.parser.add_argument( + "--test_rgb_old_wo_scratch", action="store_true", help="Same setting with origin test" + ) + + self.parser.add_argument("--test_mode", type=str, default="Crop", help="Scale|Full|Crop") + self.parser.add_argument("--Quality_restore", action="store_true", help="For RGB images") + self.parser.add_argument( + "--Scratch_and_Quality_restore", action="store_true", help="For scratched images" + ) + self.parser.add_argument("--HR", action='store_true',help='Large input size with scratches') diff --git a/Global/options/train_options.py b/Global/options/train_options.py new file mode 100644 index 0000000000000000000000000000000000000000..6cc3296657043568a3a961d793f2c69f568bab1a --- /dev/null +++ b/Global/options/train_options.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .base_options import BaseOptions + +class TrainOptions(BaseOptions): + def initialize(self): + BaseOptions.initialize(self) + # for displays + self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') + self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') + self.parser.add_argument('--save_latest_freq', type=int, default=10000, help='frequency of saving the latest results') + self.parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') + self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') + self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration') + + # for training + self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') + # self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location') + self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') + self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero') + self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') + self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') + self.parser.add_argument('--training_dataset',type=str,default='',help='training use which dataset') + + # for discriminators + self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use') + self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers') + self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') + self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss') + self.parser.add_argument('--l2_feat', type=float, help='weight for feature mapping loss') + self.parser.add_argument('--use_l1_feat', action='store_true', help='use l1 for feat mapping') + self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss') + self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss') + self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN') + self.parser.add_argument('--gan_type', type=str, default='lsgan', help='Choose the loss type of GAN') + self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images') + self.parser.add_argument('--norm_D',type=str, default='spectralinstance', help='instance normalization or batch normalization') + self.parser.add_argument('--init_D',type=str,default='xavier',help='normal|xavier|xavier_uniform|kaiming|orthogonal|none') + + self.parser.add_argument('--no_TTUR',action='store_true',help='No TTUR') + + self.parser.add_argument('--start_epoch',type=int,default=-1,help='write the start_epoch of iter.txt into this parameter') + self.parser.add_argument('--no_degradation',action='store_true',help='when train the mapping, enable this parameter --> no degradation will be added into clean image') + self.parser.add_argument('--no_load_VAE',action='store_true',help='when train the mapping, enable this parameter --> random initialize the encoder an decoder') + self.parser.add_argument('--use_v2_degradation',action='store_true',help='enable this parameter --> 4 kinds of degradations will be used to synthesize corruption') + self.parser.add_argument('--use_vae_which_epoch',type=str,default='200') + + + self.parser.add_argument('--use_focal_loss',action='store_true') + + self.parser.add_argument('--mask_need_scale',action='store_true',help='enable this param means that the pixel range of mask is 0-255') + self.parser.add_argument('--positive_weight',type=float,default=1.0,help='(For scratch detection) Since the scratch number is less, and we use a weight strategy. This parameter means that we want to decrease the weight.') + + self.parser.add_argument('--no_update_lr',action='store_true',help='use this means we do not update the LR while training') + + + self.isTrain = True diff --git a/Global/test.py b/Global/test.py new file mode 100644 index 0000000000000000000000000000000000000000..01264ed2069de188313c5cef0bbfb9fd14a638cf --- /dev/null +++ b/Global/test.py @@ -0,0 +1,192 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +from collections import OrderedDict +from torch.autograd import Variable +from options.test_options import TestOptions +from models.models import create_model +from models.mapping_model import Pix2PixHDModel_Mapping +import util.util as util +from PIL import Image +import torch +import torchvision.utils as vutils +import torchvision.transforms as transforms +import numpy as np +import cv2 + +def data_transforms(img, method=Image.BILINEAR, scale=False): + + ow, oh = img.size + pw, ph = ow, oh + if scale == True: + if ow < oh: + ow = 256 + oh = ph / pw * 256 + else: + oh = 256 + ow = pw / ph * 256 + + h = int(round(oh / 4) * 4) + w = int(round(ow / 4) * 4) + + if (h == ph) and (w == pw): + return img + + return img.resize((w, h), method) + + +def data_transforms_rgb_old(img): + w, h = img.size + A = img + if w < 256 or h < 256: + A = transforms.Scale(256, Image.BILINEAR)(img) + return transforms.CenterCrop(256)(A) + + +def irregular_hole_synthesize(img, mask): + + img_np = np.array(img).astype("uint8") + mask_np = np.array(mask).astype("uint8") + mask_np = mask_np / 255 + img_new = img_np * (1 - mask_np) + mask_np * 255 + + hole_img = Image.fromarray(img_new.astype("uint8")).convert("RGB") + + return hole_img + + +def parameter_set(opt): + ## Default parameters + opt.serial_batches = True # no shuffle + opt.no_flip = True # no flip + opt.label_nc = 0 + opt.n_downsample_global = 3 + opt.mc = 64 + opt.k_size = 4 + opt.start_r = 1 + opt.mapping_n_block = 6 + opt.map_mc = 512 + opt.no_instance = True + opt.checkpoints_dir = "./checkpoints/restoration" + ## + + if opt.Quality_restore: + opt.name = "mapping_quality" + opt.load_pretrainA = os.path.join(opt.checkpoints_dir, "VAE_A_quality") + opt.load_pretrainB = os.path.join(opt.checkpoints_dir, "VAE_B_quality") + if opt.Scratch_and_Quality_restore: + opt.NL_res = True + opt.use_SN = True + opt.correlation_renormalize = True + opt.NL_use_mask = True + opt.NL_fusion_method = "combine" + opt.non_local = "Setting_42" + opt.name = "mapping_scratch" + opt.load_pretrainA = os.path.join(opt.checkpoints_dir, "VAE_A_quality") + opt.load_pretrainB = os.path.join(opt.checkpoints_dir, "VAE_B_scratch") + if opt.HR: + opt.mapping_exp = 1 + opt.inference_optimize = True + opt.mask_dilation = 3 + opt.name = "mapping_Patch_Attention" + + +if __name__ == "__main__": + + opt = TestOptions().parse(save=False) + parameter_set(opt) + + model = Pix2PixHDModel_Mapping() + + model.initialize(opt) + model.eval() + + if not os.path.exists(opt.outputs_dir + "/" + "input_image"): + os.makedirs(opt.outputs_dir + "/" + "input_image") + if not os.path.exists(opt.outputs_dir + "/" + "restored_image"): + os.makedirs(opt.outputs_dir + "/" + "restored_image") + if not os.path.exists(opt.outputs_dir + "/" + "origin"): + os.makedirs(opt.outputs_dir + "/" + "origin") + + dataset_size = 0 + + input_loader = os.listdir(opt.test_input) + dataset_size = len(input_loader) + input_loader.sort() + + if opt.test_mask != "": + mask_loader = os.listdir(opt.test_mask) + dataset_size = len(os.listdir(opt.test_mask)) + mask_loader.sort() + + img_transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) + mask_transform = transforms.ToTensor() + + for i in range(dataset_size): + + input_name = input_loader[i] + input_file = os.path.join(opt.test_input, input_name) + if not os.path.isfile(input_file): + print("Skipping non-file %s" % input_name) + continue + input = Image.open(input_file).convert("RGB") + + print("Now you are processing %s" % (input_name)) + + if opt.NL_use_mask: + mask_name = mask_loader[i] + mask = Image.open(os.path.join(opt.test_mask, mask_name)).convert("RGB") + if opt.mask_dilation != 0: + kernel = np.ones((3,3),np.uint8) + mask = np.array(mask) + mask = cv2.dilate(mask,kernel,iterations = opt.mask_dilation) + mask = Image.fromarray(mask.astype('uint8')) + origin = input + input = irregular_hole_synthesize(input, mask) + mask = mask_transform(mask) + mask = mask[:1, :, :] ## Convert to single channel + mask = mask.unsqueeze(0) + input = img_transform(input) + input = input.unsqueeze(0) + else: + if opt.test_mode == "Scale": + input = data_transforms(input, scale=True) + if opt.test_mode == "Full": + input = data_transforms(input, scale=False) + if opt.test_mode == "Crop": + input = data_transforms_rgb_old(input) + origin = input + input = img_transform(input) + input = input.unsqueeze(0) + mask = torch.zeros_like(input) + ### Necessary input + + try: + with torch.no_grad(): + generated = model.inference(input, mask) + except Exception as ex: + print("Skip %s due to an error:\n%s" % (input_name, str(ex))) + continue + + if input_name.endswith(".jpg"): + input_name = input_name[:-4] + ".png" + + image_grid = vutils.save_image( + (input + 1.0) / 2.0, + opt.outputs_dir + "/input_image/" + input_name, + nrow=1, + padding=0, + normalize=True, + ) + image_grid = vutils.save_image( + (generated.data.cpu() + 1.0) / 2.0, + opt.outputs_dir + "/restored_image/" + input_name, + nrow=1, + padding=0, + normalize=True, + ) + + origin.save(opt.outputs_dir + "/origin/" + input_name) \ No newline at end of file diff --git a/Global/train_domain_A.py b/Global/train_domain_A.py new file mode 100644 index 0000000000000000000000000000000000000000..45004938349d674227b2fac3ad9644370c9eda30 --- /dev/null +++ b/Global/train_domain_A.py @@ -0,0 +1,147 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import time +from collections import OrderedDict +from options.train_options import TrainOptions +from data.data_loader import CreateDataLoader +from models.models import create_da_model +import util.util as util +from util.visualizer import Visualizer +import os +import numpy as np +import torch +import torchvision.utils as vutils +from torch.autograd import Variable + +opt = TrainOptions().parse() + +if opt.debug: + opt.display_freq = 1 + opt.print_freq = 1 + opt.niter = 1 + opt.niter_decay = 0 + opt.max_dataset_size = 10 + +data_loader = CreateDataLoader(opt) +dataset = data_loader.load_data() +dataset_size = len(dataset) * opt.batchSize +print('#training images = %d' % dataset_size) + +path = os.path.join(opt.checkpoints_dir, opt.name, 'model.txt') +visualizer = Visualizer(opt) + +iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') +if opt.continue_train: + try: + start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int) + except: + start_epoch, epoch_iter = 1, 0 + visualizer.print_save('Resuming from epoch %d at iteration %d' % (start_epoch - 1, epoch_iter)) +else: + start_epoch, epoch_iter = 1, 0 + +# opt.which_epoch=start_epoch-1 +model = create_da_model(opt) +fd = open(path, 'w') +fd.write(str(model.module.netG)) +fd.write(str(model.module.netD)) +fd.close() + +total_steps = (start_epoch - 1) * dataset_size + epoch_iter + +display_delta = total_steps % opt.display_freq +print_delta = total_steps % opt.print_freq +save_delta = total_steps % opt.save_latest_freq + +for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1): + epoch_start_time = time.time() + if epoch != start_epoch: + epoch_iter = epoch_iter % dataset_size + for i, data in enumerate(dataset, start=epoch_iter): + iter_start_time = time.time() + total_steps += opt.batchSize + epoch_iter += opt.batchSize + + # whether to collect output images + save_fake = total_steps % opt.display_freq == display_delta + + ############## Forward Pass ###################### + losses, generated = model(Variable(data['label']), Variable(data['inst']), + Variable(data['image']), Variable(data['feat']), infer=save_fake) + + # sum per device losses + losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses] + loss_dict = dict(zip(model.module.loss_names, losses)) + + # calculate final loss scalar + loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 + loss_featD=(loss_dict['featD_fake'] + loss_dict['featD_real']) * 0.5 + loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0) + loss_dict['G_KL'] + loss_dict['G_featD'] + + ############### Backward Pass #################### + # update generator weights + model.module.optimizer_G.zero_grad() + loss_G.backward() + model.module.optimizer_G.step() + + # update discriminator weights + model.module.optimizer_D.zero_grad() + loss_D.backward() + model.module.optimizer_D.step() + + model.module.optimizer_featD.zero_grad() + loss_featD.backward() + model.module.optimizer_featD.step() + + # call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]) + + ############## Display results and errors ########## + ### print out errors + if total_steps % opt.print_freq == print_delta: + errors = {k: v.data if not isinstance(v, int) else v for k, v in loss_dict.items()} + t = (time.time() - iter_start_time) / opt.batchSize + visualizer.print_current_errors(epoch, epoch_iter, errors, t, model.module.old_lr) + visualizer.plot_current_errors(errors, total_steps) + + ### display output images + if save_fake: + + if not os.path.exists(opt.outputs_dir + opt.name): + os.makedirs(opt.outputs_dir + opt.name) + imgs_num = data['label'].shape[0] + imgs = torch.cat((data['label'], generated.data.cpu(), data['image']), 0) + + imgs = (imgs + 1.) / 2.0 + + try: + image_grid = vutils.save_image(imgs, opt.outputs_dir + opt.name + '/' + str(epoch) + '_' + str( + total_steps) + '.png', + nrow=imgs_num, padding=0, normalize=True) + except OSError as err: + print(err) + + + if epoch_iter >= dataset_size: + break + + # end of epoch + iter_end_time = time.time() + print('End of epoch %d / %d \t Time Taken: %d sec' % + (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) + + ### save model for this epoch + if epoch % opt.save_epoch_freq == 0: + print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) + model.module.save('latest') + model.module.save(epoch) + np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d') + + ### instead of only training the local enhancer, train the entire network after certain iterations + if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global): + model.module.update_fixed_params() + + ### linearly decay learning rate after certain iterations + if epoch > opt.niter: + model.module.update_learning_rate() + diff --git a/Global/train_domain_B.py b/Global/train_domain_B.py new file mode 100644 index 0000000000000000000000000000000000000000..4659b047290b3f5e95a483713b3ad9984a4f94f5 --- /dev/null +++ b/Global/train_domain_B.py @@ -0,0 +1,145 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import time +from collections import OrderedDict +from options.train_options import TrainOptions +from data.data_loader import CreateDataLoader +from models.models import create_model +import util.util as util +from util.visualizer import Visualizer +import os +import numpy as np +import torch +import torchvision.utils as vutils +from torch.autograd import Variable +import random + + +opt = TrainOptions().parse() + +if opt.debug: + opt.display_freq = 1 + opt.print_freq = 1 + opt.niter = 1 + opt.niter_decay = 0 + opt.max_dataset_size = 10 + +data_loader = CreateDataLoader(opt) +dataset = data_loader.load_data() +dataset_size = len(dataset) * opt.batchSize +print('#training images = %d' % dataset_size) + +path = os.path.join(opt.checkpoints_dir, opt.name, 'model.txt') +visualizer = Visualizer(opt) + + +iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') +if opt.continue_train: + try: + start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int) + except: + start_epoch, epoch_iter = 1, 0 + visualizer.print_save('Resuming from epoch %d at iteration %d' % (start_epoch-1, epoch_iter)) +else: + start_epoch, epoch_iter = 1, 0 + +# opt.which_epoch=start_epoch-1 +model = create_model(opt) +fd = open(path, 'w') +fd.write(str(model.module.netG)) +fd.write(str(model.module.netD)) +fd.close() + +total_steps = (start_epoch-1) * dataset_size + epoch_iter + +display_delta = total_steps % opt.display_freq +print_delta = total_steps % opt.print_freq +save_delta = total_steps % opt.save_latest_freq + +for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1): + epoch_start_time = time.time() + if epoch != start_epoch: + epoch_iter = epoch_iter % dataset_size + for i, data in enumerate(dataset, start=epoch_iter): + iter_start_time = time.time() + total_steps += opt.batchSize + epoch_iter += opt.batchSize + + # whether to collect output images + save_fake = total_steps % opt.display_freq == display_delta + + ############## Forward Pass ###################### + losses, generated = model(Variable(data['label']), Variable(data['inst']), + Variable(data['image']), Variable(data['feat']), infer=save_fake) + + # sum per device losses + losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ] + loss_dict = dict(zip(model.module.loss_names, losses)) + + + # calculate final loss scalar + loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 + loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0) + loss_dict['G_KL'] + loss_dict.get('Smooth_L1',0)*opt.L1_weight + + + ############### Backward Pass #################### + # update generator weights + model.module.optimizer_G.zero_grad() + loss_G.backward() + model.module.optimizer_G.step() + + # update discriminator weights + model.module.optimizer_D.zero_grad() + loss_D.backward() + model.module.optimizer_D.step() + + #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]) + + ############## Display results and errors ########## + ### print out errors + if total_steps % opt.print_freq == print_delta: + errors = {k: v.data if not isinstance(v, int) else v for k, v in loss_dict.items()} + t = (time.time() - iter_start_time) / opt.batchSize + visualizer.print_current_errors(epoch, epoch_iter, errors, t, model.module.old_lr) + visualizer.plot_current_errors(errors, total_steps) + + ### display output images + if save_fake: + + if not os.path.exists(opt.outputs_dir + opt.name): + os.makedirs(opt.outputs_dir + opt.name) + imgs_num = 5 + imgs = torch.cat((data['label'][:imgs_num], generated.data.cpu()[:imgs_num], data['image'][:imgs_num]), 0) + + imgs = (imgs + 1.) / 2.0 + + try: + image_grid = vutils.save_image(imgs, opt.outputs_dir + opt.name + '/' + str(epoch) + '_' + str(total_steps) + '.png', + nrow=imgs_num, padding=0, normalize=True) + except OSError as err: + print(err) + + if epoch_iter >= dataset_size: + break + + # end of epoch + iter_end_time = time.time() + print('End of epoch %d / %d \t Time Taken: %d sec' % + (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) + + ### save model for this epoch + if epoch % opt.save_epoch_freq == 0: + print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) + model.module.save('latest') + model.module.save(epoch) + np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d') + + ### instead of only training the local enhancer, train the entire network after certain iterations + if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global): + model.module.update_fixed_params() + + ### linearly decay learning rate after certain iterations + if epoch > opt.niter: + model.module.update_learning_rate() + diff --git a/Global/train_mapping.py b/Global/train_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..ffff4a5de7622e831989e8cb0daa694325a345b5 --- /dev/null +++ b/Global/train_mapping.py @@ -0,0 +1,162 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import time +from collections import OrderedDict +from options.train_options import TrainOptions +from data.data_loader import CreateDataLoader +from models.mapping_model import Pix2PixHDModel_Mapping +import util.util as util +from util.visualizer import Visualizer +import os +import numpy as np +import torch +import torchvision.utils as vutils +from torch.autograd import Variable +import datetime +import random + + + +opt = TrainOptions().parse() +visualizer = Visualizer(opt) +iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') +if opt.continue_train: + try: + start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int) + except: + start_epoch, epoch_iter = 1, 0 + visualizer.print_save('Resuming from epoch %d at iteration %d' % (start_epoch-1, epoch_iter)) +else: + start_epoch, epoch_iter = 1, 0 + +if opt.which_epoch != "latest": + start_epoch=int(opt.which_epoch) + visualizer.print_save('Notice : Resuming from epoch %d at iteration %d' % (start_epoch - 1, epoch_iter)) + +opt.start_epoch=start_epoch +### temp for continue train unfixed decoder + +data_loader = CreateDataLoader(opt) +dataset = data_loader.load_data() +dataset_size = len(dataset) * opt.batchSize +print('#training images = %d' % dataset_size) + + +model = Pix2PixHDModel_Mapping() +model.initialize(opt) + +path = os.path.join(opt.checkpoints_dir, opt.name, 'model.txt') +fd = open(path, 'w') + +if opt.use_skip_model: + fd.write(str(model.mapping_net)) + fd.close() +else: + fd.write(str(model.netG_A)) + fd.write(str(model.mapping_net)) + fd.close() + +if opt.isTrain and len(opt.gpu_ids) > 1: + model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) + + + +total_steps = (start_epoch-1) * dataset_size + epoch_iter + +display_delta = total_steps % opt.display_freq +print_delta = total_steps % opt.print_freq +save_delta = total_steps % opt.save_latest_freq +### used for recovering training + +for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1): + epoch_s_t=datetime.datetime.now() + epoch_start_time = time.time() + if epoch != start_epoch: + epoch_iter = epoch_iter % dataset_size + for i, data in enumerate(dataset, start=epoch_iter): + iter_start_time = time.time() + total_steps += opt.batchSize + epoch_iter += opt.batchSize + + # whether to collect output images + save_fake = total_steps % opt.display_freq == display_delta + + ############## Forward Pass ###################### + #print(pair) + losses, generated = model(Variable(data['label']), Variable(data['inst']), + Variable(data['image']), Variable(data['feat']), infer=save_fake) + + # sum per device losses + losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ] + loss_dict = dict(zip(model.module.loss_names, losses)) + + # calculate final loss scalar + loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 + loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0) + loss_dict.get('G_Feat_L2', 0) +loss_dict.get('Smooth_L1', 0)+loss_dict.get('G_Feat_L2_Stage_1',0) + #loss_G = loss_dict['G_Feat_L2'] + + ############### Backward Pass #################### + # update generator weights + model.module.optimizer_mapping.zero_grad() + loss_G.backward() + model.module.optimizer_mapping.step() + + # update discriminator weights + model.module.optimizer_D.zero_grad() + loss_D.backward() + model.module.optimizer_D.step() + + ############## Display results and errors ########## + ### print out errors + if i == 0 or total_steps % opt.print_freq == print_delta: + errors = {k: v.data if not isinstance(v, int) else v for k, v in loss_dict.items()} + t = (time.time() - iter_start_time) / opt.batchSize + visualizer.print_current_errors(epoch, epoch_iter, errors, t,model.module.old_lr) + visualizer.plot_current_errors(errors, total_steps) + + ### display output images + if save_fake: + + if not os.path.exists(opt.outputs_dir + opt.name): + os.makedirs(opt.outputs_dir + opt.name) + + imgs_num = 5 + if opt.NL_use_mask: + mask=data['inst'][:imgs_num] + mask=mask.repeat(1,3,1,1) + imgs = torch.cat((data['label'][:imgs_num], mask,generated.data.cpu()[:imgs_num], data['image'][:imgs_num]), 0) + else: + imgs = torch.cat((data['label'][:imgs_num], generated.data.cpu()[:imgs_num], data['image'][:imgs_num]), 0) + + imgs=(imgs+1.)/2.0 ## de-normalize + + try: + image_grid = vutils.save_image(imgs, opt.outputs_dir + opt.name + '/' + str(epoch) + '_' + str(total_steps) + '.png', + nrow=imgs_num, padding=0, normalize=True) + except OSError as err: + print(err) + + if epoch_iter >= dataset_size: + break + + # end of epoch + epoch_e_t=datetime.datetime.now() + iter_end_time = time.time() + print('End of epoch %d / %d \t Time Taken: %s' % + (epoch, opt.niter + opt.niter_decay, str(epoch_e_t-epoch_s_t))) + + ### save model for this epoch + if epoch % opt.save_epoch_freq == 0: + print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) + model.module.save('latest') + model.module.save(epoch) + np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d') + + ### instead of only training the local enhancer, train the entire network after certain iterations + if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global): + model.module.update_fixed_params() + + ### linearly decay learning rate after certain iterations + if epoch > opt.niter: + model.module.update_learning_rate() \ No newline at end of file diff --git a/Global/util/__init__.py b/Global/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Global/util/image_pool.py b/Global/util/image_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..1e7846e7c203f5a3d3f8d7187f906990762396fa --- /dev/null +++ b/Global/util/image_pool.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import random +import torch +from torch.autograd import Variable + + +class ImagePool: + def __init__(self, pool_size): + self.pool_size = pool_size + if self.pool_size > 0: + self.num_imgs = 0 + self.images = [] + + def query(self, images): + if self.pool_size == 0: + return images + return_images = [] + for image in images.data: + image = torch.unsqueeze(image, 0) + if self.num_imgs < self.pool_size: + self.num_imgs = self.num_imgs + 1 + self.images.append(image) + return_images.append(image) + else: + p = random.uniform(0, 1) + if p > 0.5: + random_id = random.randint(0, self.pool_size - 1) + tmp = self.images[random_id].clone() + self.images[random_id] = image + return_images.append(tmp) + else: + return_images.append(image) + return_images = Variable(torch.cat(return_images, 0)) + return return_images diff --git a/Global/util/util.py b/Global/util/util.py new file mode 100644 index 0000000000000000000000000000000000000000..b1369c3b568548a1c21d3412aef5fd35c9b0c5be --- /dev/null +++ b/Global/util/util.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import print_function +import torch +import numpy as np +from PIL import Image +import numpy as np +import os +import torch.nn as nn + +# Converts a Tensor into a Numpy array +# |imtype|: the desired type of the converted numpy array +def tensor2im(image_tensor, imtype=np.uint8, normalize=True): + if isinstance(image_tensor, list): + image_numpy = [] + for i in range(len(image_tensor)): + image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) + return image_numpy + image_numpy = image_tensor.cpu().float().numpy() + if normalize: + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + else: + image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 + image_numpy = np.clip(image_numpy, 0, 255) + if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3: + image_numpy = image_numpy[:, :, 0] + return image_numpy.astype(imtype) + + +# Converts a one-hot tensor into a colorful label map +def tensor2label(label_tensor, n_label, imtype=np.uint8): + if n_label == 0: + return tensor2im(label_tensor, imtype) + label_tensor = label_tensor.cpu().float() + if label_tensor.size()[0] > 1: + label_tensor = label_tensor.max(0, keepdim=True)[1] + label_tensor = Colorize(n_label)(label_tensor) + label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) + return label_numpy.astype(imtype) + + +def save_image(image_numpy, image_path): + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path) + + +def mkdirs(paths): + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) diff --git a/Global/util/visualizer.py b/Global/util/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1a88df203aa95750ba911c77b32f6234863b8e79 --- /dev/null +++ b/Global/util/visualizer.py @@ -0,0 +1,143 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import numpy as np +import os +import ntpath +import time +from . import util +#from . import html +import scipy.misc +try: + from StringIO import StringIO # Python 2.7 +except ImportError: + from io import BytesIO # Python 3.x + +class Visualizer(): + def __init__(self, opt): + # self.opt = opt + self.tf_log = opt.tf_log + self.use_html = opt.isTrain and not opt.no_html + self.win_size = opt.display_winsize + self.name = opt.name + if self.tf_log: + import tensorflow as tf + self.tf = tf + self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs') + self.writer = tf.summary.FileWriter(self.log_dir) + + if self.use_html: + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + print('create web directory %s...' % self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + # |visuals|: dictionary of images to display or save + def display_current_results(self, visuals, epoch, step): + if self.tf_log: # show images in tensorboard output + img_summaries = [] + for label, image_numpy in visuals.items(): + # Write the image to a string + try: + s = StringIO() + except: + s = BytesIO() + scipy.misc.toimage(image_numpy).save(s, format="jpeg") + # Create an Image object + img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1]) + # Create a Summary value + img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum)) + + # Create and write Summary + summary = self.tf.Summary(value=img_summaries) + self.writer.add_summary(summary, step) + + if self.use_html: # save images to a html file + for label, image_numpy in visuals.items(): + if isinstance(image_numpy, list): + for i in range(len(image_numpy)): + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i)) + util.save_image(image_numpy[i], img_path) + else: + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.jpg' % (epoch, label)) + util.save_image(image_numpy, img_path) + + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=30) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims = [] + txts = [] + links = [] + + for label, image_numpy in visuals.items(): + if isinstance(image_numpy, list): + for i in range(len(image_numpy)): + img_path = 'epoch%.3d_%s_%d.jpg' % (n, label, i) + ims.append(img_path) + txts.append(label+str(i)) + links.append(img_path) + else: + img_path = 'epoch%.3d_%s.jpg' % (n, label) + ims.append(img_path) + txts.append(label) + links.append(img_path) + if len(ims) < 10: + webpage.add_images(ims, txts, links, width=self.win_size) + else: + num = int(round(len(ims)/2.0)) + webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size) + webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) + webpage.save() + + # errors: dictionary of error labels and values + def plot_current_errors(self, errors, step): + if self.tf_log: + for tag, value in errors.items(): + summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) + self.writer.add_summary(summary, step) + + # errors: same format as |errors| of plotCurrentErrors + def print_current_errors(self, epoch, i, errors, t, lr): + message = '(epoch: %d, iters: %d, time: %.3f lr: %.5f) ' % (epoch, i, t, lr) + for k, v in errors.items(): + if v != 0: + message += '%s: %.3f ' % (k, v) + + print(message) + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) + + + def print_save(self,message): + + print(message) + + with open(self.log_name,"a") as log_file: + log_file.write('%s\n'%message) + + + # save image to the disk + def save_images(self, webpage, visuals, image_path): + image_dir = webpage.get_image_dir() + short_path = ntpath.basename(image_path[0]) + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims = [] + txts = [] + links = [] + + for label, image_numpy in visuals.items(): + image_name = '%s_%s.jpg' % (name, label) + save_path = os.path.join(image_dir, image_name) + util.save_image(image_numpy, save_path) + + ims.append(image_name) + txts.append(label) + links.append(image_name) + webpage.add_images(ims, txts, links, width=self.win_size) diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..9e841e7a26e4eb057b24511e7b92d42b257a80e5 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE diff --git a/README.md b/README.md index 6adb7c53fff72896b13b4183eb748af87df47c8f..2f57199c1f45de1150ec5a44e9d69dfe2ac4a8b2 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,37 @@ --- -title: Image_Restoration_Colorization -emoji: 📚 -colorFrom: gray -colorTo: indigo +title: ImageRestoration +emoji: 🤗 +colorFrom: blue +colorTo: yellow sdk: gradio app_file: app.py pinned: false --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference +# Configuration + +`title`: _string_ +Display title for the Space + +`emoji`: _string_ +Space emoji (emoji-only character allowed) + +`colorFrom`: _string_ +Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray) + +`colorTo`: _string_ +Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray) + +`sdk`: _string_ +Can be either `gradio` or `streamlit` + +`sdk_version` : _string_ +Only applicable for `streamlit` SDK. +See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions. + +`app_file`: _string_ +Path to your main application file (which contains either `gradio` or `streamlit` Python code). +Path is relative to the root of the repository. + +`pinned`: _boolean_ +Whether the Space stays on top of your list. \ No newline at end of file diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000000000000000000000000000000000000..f7b89984f0fb5dd204028bc525e19eefc0859f4f --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,41 @@ + + +## Security + +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). + +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. + +## Reporting Security Issues + +**Please do not report security vulnerabilities through public GitHub issues.** + +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). + +If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). + +You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). + +Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: + + * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) + * Full paths of source file(s) related to the manifestation of the issue + * The location of the affected source code (tag/branch/commit or direct URL) + * Any special configuration required to reproduce the issue + * Step-by-step instructions to reproduce the issue + * Proof-of-concept or exploit code (if possible) + * Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. + +## Preferred Languages + +We prefer all communications to be in English. + +## Policy + +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). + + \ No newline at end of file diff --git a/ansible.yaml b/ansible.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b10f750194ed582679a8810ba2a9688c726ce4a4 --- /dev/null +++ b/ansible.yaml @@ -0,0 +1,107 @@ +--- +- name: Bringing-Old-Photos-Back-to-Life + hosts: all + gather_facts: no + +# Succesfully tested on Ubuntu 18.04\20.04 and Debian 10 + + pre_tasks: + - name: install packages + package: + name: + - python3 + - python3-pip + - python3-venv + - git + - unzip + - tar + - lbzip2 + - build-essential + - cmake + - ffmpeg + - libsm6 + - libxext6 + - libgl1-mesa-glx + state: latest + become: yes + + tasks: + - name: git clone repo + git: + repo: 'https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life.git' + dest: Bringing-Old-Photos-Back-to-Life + clone: yes + + - name: requirements setup + pip: + requirements: "~/Bringing-Old-Photos-Back-to-Life/requirements.txt" + virtualenv: "~/Bringing-Old-Photos-Back-to-Life/.venv" + virtualenv_command: /usr/bin/python3 -m venv .venv + + - name: additional pip packages #requirements lack some packs + pip: + name: + - setuptools + - wheel + - scikit-build + virtualenv: "~/Bringing-Old-Photos-Back-to-Life/.venv" + virtualenv_command: /usr/bin/python3 -m venv .venv + + - name: git clone batchnorm-pytorch + git: + repo: 'https://github.com/vacancy/Synchronized-BatchNorm-PyTorch' + dest: Synchronized-BatchNorm-PyTorch + clone: yes + + - name: copy sync_batchnorm to face_enhancement + copy: + src: Synchronized-BatchNorm-PyTorch/sync_batchnorm + dest: Bringing-Old-Photos-Back-to-Life/Face_Enhancement/models/networks/ + remote_src: yes + + - name: copy sync_batchnorm to global + copy: + src: Synchronized-BatchNorm-PyTorch/sync_batchnorm + dest: Bringing-Old-Photos-Back-to-Life/Global/detection_models + remote_src: yes + + - name: check if shape_predictor_68_face_landmarks.dat + stat: + path: Bringing-Old-Photos-Back-to-Life/Face_Detection/shape_predictor_68_face_landmarks.dat + register: p + + - name: get shape_predictor_68_face_landmarks.dat.bz2 + get_url: + url: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 + dest: Bringing-Old-Photos-Back-to-Life/Face_Detection/ + when: p.stat.exists == False + + - name: unarchive shape_predictor_68_face_landmarks.dat.bz2 + shell: 'bzip2 -d Bringing-Old-Photos-Back-to-Life/Face_Detection/shape_predictor_68_face_landmarks.dat.bz2' + when: p.stat.exists == False + + - name: check if face_enhancement + stat: + path: Bringing-Old-Photos-Back-to-Life/Face_Enhancement/checkpoints/Setting_9_epoch_100/latest_net_G.pth + register: fc + + - name: unarchive Face_Enhancement/checkpoints.zip + unarchive: + src: https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Face_Enhancement/checkpoints.zip + dest: Bringing-Old-Photos-Back-to-Life/Face_Enhancement/ + remote_src: yes + when: fc.stat.exists == False + + - name: check if global + stat: + path: Bringing-Old-Photos-Back-to-Life/Global/checkpoints/detection/FT_Epoch_latest.pt + register: gc + + - name: unarchive Global/checkpoints.zip + unarchive: + src: https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Global/checkpoints.zip + dest: Bringing-Old-Photos-Back-to-Life/Global/ + remote_src: yes + when: gc.stat.exists == False + +# Do not forget to execute 'source .venv/bin/activate' inside Bringing-Old-Photos-Back-to-Life before starting run.py \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..6828e875bdb092d2241e91585fb9986f9dd06eac --- /dev/null +++ b/app.py @@ -0,0 +1,133 @@ +import gradio as gr +import os +import cv2 +import shutil +import sys +from subprocess import call +import torch +import numpy as np +from skimage import color +import torchvision.transforms as transforms +from PIL import Image +import torch + +os.system("pip install dlib") +os.system('bash setup.sh') + +def lab2rgb(L, AB): + """Convert an Lab tensor image to a RGB numpy output + Parameters: + L (1-channel tensor array): L channel images (range: [-1, 1], torch tensor array) + AB (2-channel tensor array): ab channel images (range: [-1, 1], torch tensor array) + + Returns: + rgb (RGB numpy image): rgb output images (range: [0, 255], numpy array) + """ + AB2 = AB * 110.0 + L2 = (L + 1.0) * 50.0 + Lab = torch.cat([L2, AB2], dim=1) + Lab = Lab[0].data.cpu().float().numpy() + Lab = np.transpose(Lab.astype(np.float64), (1, 2, 0)) + rgb = color.lab2rgb(Lab) * 255 + return rgb + +def get_transform(params=None, grayscale=False, method=Image.BICUBIC): + #params + preprocess = 'resize_and_crop' + load_size = 256 + crop_size = 256 + transform_list = [] + if grayscale: + transform_list.append(transforms.Grayscale(1)) + if 'resize' in preprocess: + osize = [load_size, load_size] + transform_list.append(transforms.Resize(osize, method)) + if 'crop' in preprocess: + if params is None: + transform_list.append(transforms.RandomCrop(crop_size)) + + return transforms.Compose(transform_list) + +def inferColorization(img,model_name): + print(model_name) + if model_name == "Pix2Pix_resnet9b": + model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixColorization_resnet9b') + elif model_name == "Pix2Pix_unet256": + model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'pix2pixColorization_unet256') + elif model_name == "Deoldify": + model = torch.hub.load('manhkhanhad/ImageRestorationInfer', 'DeOldifyColorization') + transform_list = [ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ] + transform = transforms.Compose(transform_list) + #a = transforms.ToTensor()(a) + img = img.convert('L') + img = transform(img) + img = torch.unsqueeze(img, 0) + result = model(img) + + result = result[0].detach() + result = (result +1)/2.0 + + #img = transforms.Grayscale(3)(img) + #img = transforms.ToTensor()(img) + #img = torch.unsqueeze(img, 0) + #result = model(img) + #result = torch.clip(result, min=0, max=1) + image_pil = transforms.ToPILImage()(result) + return image_pil + + transform_seq = get_transform() + im = transform_seq(img) + im = np.array(img) + lab = color.rgb2lab(im).astype(np.float32) + lab_t = transforms.ToTensor()(lab) + A = lab_t[[0], ...] / 50.0 - 1.0 + B = lab_t[[1, 2], ...] / 110.0 + #data = {'A': A, 'B': B, 'A_paths': "", 'B_paths': ""} + L = torch.unsqueeze(A, 0) + #print(L.shape) + ab = model(L) + Lab = lab2rgb(L, ab).astype(np.uint8) + image_pil = Image.fromarray(Lab) + #image_pil.save('test.png') + #print(Lab.shape) + return image_pil + +def colorizaition(image,model_name): + image = Image.fromarray(image) + result = inferColorization(image,model_name) + return result + + +def run_cmd(command): + try: + call(command, shell=True) + except KeyboardInterrupt: + print("Process interrupted") + sys.exit(1) + +def run(image): + os.makedirs("Temp") + os.makedirs("Temp/input") + print(type(image)) + cv2.imwrite("Temp/input/input_img.png", image) + + command = ("python run.py --input_folder " + + "Temp/input" + + " --output_folder " + + "Temp" + + " --GPU " + + "-1" + + " --with_scratch") + run_cmd(command) + + result_restoration = Image.open("Temp/final_output/input_img.png") + shutil.rmtree("Temp") + + result_colorization = inferColorization(result_restoration,"Deoldify") + + return result_colorization + +iface = gr.Interface(fn=run, inputs="image", outputs="image").launch(debug=True,share=True) \ No newline at end of file diff --git a/cog.yaml b/cog.yaml new file mode 100755 index 0000000000000000000000000000000000000000..19cc6749b7bbbe3ec606731a0d51c5b26b946004 --- /dev/null +++ b/cog.yaml @@ -0,0 +1,26 @@ +build: + gpu: true + python_version: "3.8" + system_packages: + - "libgl1-mesa-glx" + - "libglib2.0-0" + python_packages: + - "cmake==3.21.2" + - "torchvision==0.9.0" + - "torch==1.8.0" + - "numpy==1.19.4" + - "opencv-python==4.4.0.46" + - "scipy==1.5.3" + - "tensorboardX==2.4" + - "dominate==2.6.0" + - "easydict==1.9" + - "PyYAML==5.3.1" + - "scikit-image==0.18.3" + - "dill==0.3.4" + - "einops==0.3.0" + - "PySimpleGUI==4.46.0" + - "ipython==7.19.0" + run: + - pip install dlib + +predict: "predict.py:Predictor" diff --git a/download-weights b/download-weights new file mode 100755 index 0000000000000000000000000000000000000000..481781019d569033cb5e68bc3944e10374cc9077 --- /dev/null +++ b/download-weights @@ -0,0 +1,28 @@ +#!/bin/sh + +cd Face_Enhancement/models/networks +git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . +cd ../../../ + +cd Global/detection_models +git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . +cd ../../ + +# download the landmark detection model +cd Face_Detection/ +wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 +bzip2 -d shape_predictor_68_face_landmarks.dat.bz2 +cd ../ + +# download the pretrained model +cd Face_Enhancement/ +wget https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Face_Enhancement/checkpoints.zip +unzip checkpoints.zip +cd ../ + +cd Global/ +wget https://facevc.blob.core.windows.net/zhanbo/old_photo/pretrain/Global/checkpoints.zip +unzip checkpoints.zip +cd ../ diff --git a/kubernetes-pod.yml b/kubernetes-pod.yml new file mode 100644 index 0000000000000000000000000000000000000000..83ab82ac2a1b9638ca5002460448e7ce9067ad15 --- /dev/null +++ b/kubernetes-pod.yml @@ -0,0 +1,38 @@ +apiVersion: v1 +kind: Pod +metadata: + name: photo-back2life +spec: + containers: + - name: photos-back2life + image: + volumeMounts: + - mountPath: /in + name: in-folder + - mountPath: /out + name: out-folder + command: + - python + - /app/run.py + args: + - --input_folder + - /in + - --output_folder + - /out + - --GPU + - '0' + - --with_scratch + resources: + limits: + memory: 4Gi + cpu: 0 + nvidia.com/gpu: 1 + volumes: + - name: in-folder + hostPath: + path: /srv/in + type: Directory + - name: out-folder + hostPath: + path: /srv/out + type: Directory diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..8ba3b394be4f59e94a9e575ba001272cba649c0f --- /dev/null +++ b/packages.txt @@ -0,0 +1 @@ +ffmpeg libsm6 libxext6 -y diff --git a/predict.py b/predict.py new file mode 100755 index 0000000000000000000000000000000000000000..5573cd1a64d8357641299517338011e7e1aa1ac1 --- /dev/null +++ b/predict.py @@ -0,0 +1,222 @@ +import tempfile +from pathlib import Path +import argparse +import shutil +import os +import glob +import cv2 +import cog +from run import run_cmd + + +class Predictor(cog.Predictor): + def setup(self): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_folder", type=str, default="input/cog_temp", help="Test images" + ) + parser.add_argument( + "--output_folder", + type=str, + default="output", + help="Restored images, please use the absolute path", + ) + parser.add_argument("--GPU", type=str, default="0", help="0,1,2") + parser.add_argument( + "--checkpoint_name", + type=str, + default="Setting_9_epoch_100", + help="choose which checkpoint", + ) + self.opts = parser.parse_args("") + self.basepath = os.getcwd() + self.opts.input_folder = os.path.join(self.basepath, self.opts.input_folder) + self.opts.output_folder = os.path.join(self.basepath, self.opts.output_folder) + os.makedirs(self.opts.input_folder, exist_ok=True) + os.makedirs(self.opts.output_folder, exist_ok=True) + + @cog.input("image", type=Path, help="input image") + @cog.input( + "HR", + type=bool, + default=False, + help="whether the input image is high-resolution", + ) + @cog.input( + "with_scratch", + type=bool, + default=False, + help="whether the input image is scratched", + ) + def predict(self, image, HR=False, with_scratch=False): + try: + os.chdir(self.basepath) + input_path = os.path.join(self.opts.input_folder, os.path.basename(image)) + shutil.copy(str(image), input_path) + + gpu1 = self.opts.GPU + + ## Stage 1: Overall Quality Improve + print("Running Stage 1: Overall restoration") + os.chdir("./Global") + stage_1_input_dir = self.opts.input_folder + stage_1_output_dir = os.path.join( + self.opts.output_folder, "stage_1_restore_output" + ) + + os.makedirs(stage_1_output_dir, exist_ok=True) + + if not with_scratch: + + stage_1_command = ( + "python test.py --test_mode Full --Quality_restore --test_input " + + stage_1_input_dir + + " --outputs_dir " + + stage_1_output_dir + + " --gpu_ids " + + gpu1 + ) + run_cmd(stage_1_command) + else: + + mask_dir = os.path.join(stage_1_output_dir, "masks") + new_input = os.path.join(mask_dir, "input") + new_mask = os.path.join(mask_dir, "mask") + stage_1_command_1 = ( + "python detection.py --test_path " + + stage_1_input_dir + + " --output_dir " + + mask_dir + + " --input_size full_size" + + " --GPU " + + gpu1 + ) + + if HR: + HR_suffix = " --HR" + else: + HR_suffix = "" + + stage_1_command_2 = ( + "python test.py --Scratch_and_Quality_restore --test_input " + + new_input + + " --test_mask " + + new_mask + + " --outputs_dir " + + stage_1_output_dir + + " --gpu_ids " + + gpu1 + + HR_suffix + ) + + run_cmd(stage_1_command_1) + run_cmd(stage_1_command_2) + + ## Solve the case when there is no face in the old photo + stage_1_results = os.path.join(stage_1_output_dir, "restored_image") + stage_4_output_dir = os.path.join(self.opts.output_folder, "final_output") + os.makedirs(stage_4_output_dir, exist_ok=True) + for x in os.listdir(stage_1_results): + img_dir = os.path.join(stage_1_results, x) + shutil.copy(img_dir, stage_4_output_dir) + + print("Finish Stage 1 ...") + print("\n") + + ## Stage 2: Face Detection + + print("Running Stage 2: Face Detection") + os.chdir(".././Face_Detection") + stage_2_input_dir = os.path.join(stage_1_output_dir, "restored_image") + stage_2_output_dir = os.path.join( + self.opts.output_folder, "stage_2_detection_output" + ) + os.makedirs(stage_2_output_dir, exist_ok=True) + + stage_2_command = ( + "python detect_all_dlib_HR.py --url " + + stage_2_input_dir + + " --save_url " + + stage_2_output_dir + ) + + run_cmd(stage_2_command) + print("Finish Stage 2 ...") + print("\n") + + ## Stage 3: Face Restore + print("Running Stage 3: Face Enhancement") + os.chdir(".././Face_Enhancement") + stage_3_input_mask = "./" + stage_3_input_face = stage_2_output_dir + stage_3_output_dir = os.path.join( + self.opts.output_folder, "stage_3_face_output" + ) + + os.makedirs(stage_3_output_dir, exist_ok=True) + + self.opts.checkpoint_name = "FaceSR_512" + stage_3_command = ( + "python test_face.py --old_face_folder " + + stage_3_input_face + + " --old_face_label_folder " + + stage_3_input_mask + + " --tensorboard_log --name " + + self.opts.checkpoint_name + + " --gpu_ids " + + gpu1 + + " --load_size 512 --label_nc 18 --no_instance --preprocess_mode resize --batchSize 1 --results_dir " + + stage_3_output_dir + + " --no_parsing_map" + ) + + run_cmd(stage_3_command) + print("Finish Stage 3 ...") + print("\n") + + ## Stage 4: Warp back + print("Running Stage 4: Blending") + os.chdir(".././Face_Detection") + stage_4_input_image_dir = os.path.join(stage_1_output_dir, "restored_image") + stage_4_input_face_dir = os.path.join(stage_3_output_dir, "each_img") + stage_4_output_dir = os.path.join(self.opts.output_folder, "final_output") + os.makedirs(stage_4_output_dir, exist_ok=True) + + stage_4_command = ( + "python align_warp_back_multiple_dlib_HR.py --origin_url " + + stage_4_input_image_dir + + " --replace_url " + + stage_4_input_face_dir + + " --save_url " + + stage_4_output_dir + ) + + run_cmd(stage_4_command) + print("Finish Stage 4 ...") + print("\n") + + print("All the processing is done. Please check the results.") + + final_output = os.listdir(os.path.join(self.opts.output_folder, "final_output"))[0] + + image_restore = cv2.imread(os.path.join(self.opts.output_folder, "final_output", final_output)) + + out_path = Path(tempfile.mkdtemp()) / "out.png" + + cv2.imwrite(str(out_path), image_restore) + finally: + clean_folder(self.opts.input_folder) + clean_folder(self.opts.output_folder) + return out_path + + +def clean_folder(folder): + for filename in os.listdir(folder): + file_path = os.path.join(folder, filename) + try: + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + print(f"Failed to delete {file_path}. Reason:{e}") diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..8e97b30762e64361feaad2792ea190f6751ad59a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,15 @@ +torch +torchvision +#dlib +scikit-image +easydict +PyYAML +dominate>=2.4.0 +dill +tensorboardX +scipy +opencv-python +einops +PySimpleGUI +fastai==2.4.0 +visdom>=0.1.8.8 \ No newline at end of file diff --git a/run.py b/run.py new file mode 100644 index 0000000000000000000000000000000000000000..d78448600473d74939d4a820e1b9910f46cc8034 --- /dev/null +++ b/run.py @@ -0,0 +1,202 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +import argparse +import shutil +import sys +from subprocess import call + +def run_cmd(command): + try: + call(command, shell=True) + except KeyboardInterrupt: + print("Process interrupted") + sys.exit(1) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("--input_folder", type=str, default="./test_images/old", help="Test images") + parser.add_argument( + "--output_folder", + type=str, + default="./output", + help="Restored images, please use the absolute path", + ) + parser.add_argument("--GPU", type=str, default="6,7", help="0,1,2") + parser.add_argument( + "--checkpoint_name", type=str, default="Setting_9_epoch_100", help="choose which checkpoint" + ) + parser.add_argument("--with_scratch", action="store_true") + parser.add_argument("--HR", action='store_true') + opts = parser.parse_args() + + gpu1 = opts.GPU + + # resolve relative paths before changing directory + opts.input_folder = os.path.abspath(opts.input_folder) + opts.output_folder = os.path.abspath(opts.output_folder) + if not os.path.exists(opts.output_folder): + os.makedirs(opts.output_folder) + + main_environment = os.getcwd() + + ## Stage 1: Overall Quality Improve + print("Running Stage 1: Overall restoration") + os.chdir("./Global") + stage_1_input_dir = opts.input_folder + stage_1_output_dir = os.path.join(opts.output_folder, "stage_1_restore_output") + if not os.path.exists(stage_1_output_dir): + os.makedirs(stage_1_output_dir) + + if not opts.with_scratch: + stage_1_command = ( + "python test.py --test_mode Full --Quality_restore --test_input " + + stage_1_input_dir + + " --outputs_dir " + + stage_1_output_dir + + " --gpu_ids " + + gpu1 + ) + run_cmd(stage_1_command) + else: + + mask_dir = os.path.join(stage_1_output_dir, "masks") + new_input = os.path.join(mask_dir, "input") + new_mask = os.path.join(mask_dir, "mask") + stage_1_command_1 = ( + "python detection.py --test_path " + + stage_1_input_dir + + " --output_dir " + + mask_dir + + " --input_size full_size" + + " --GPU " + + gpu1 + ) + + if opts.HR: + HR_suffix=" --HR" + else: + HR_suffix="" + + stage_1_command_2 = ( + "python test.py --Scratch_and_Quality_restore --test_input " + + new_input + + " --test_mask " + + new_mask + + " --outputs_dir " + + stage_1_output_dir + + " --gpu_ids " + + gpu1 + HR_suffix + ) + + run_cmd(stage_1_command_1) + run_cmd(stage_1_command_2) + + ## Solve the case when there is no face in the old photo + stage_1_results = os.path.join(stage_1_output_dir, "restored_image") + stage_4_output_dir = os.path.join(opts.output_folder, "final_output") + if not os.path.exists(stage_4_output_dir): + os.makedirs(stage_4_output_dir) + for x in os.listdir(stage_1_results): + img_dir = os.path.join(stage_1_results, x) + shutil.copy(img_dir, stage_4_output_dir) + + print("Finish Stage 1 ...") + print("\n") + + ## Stage 2: Face Detection + + print("Running Stage 2: Face Detection") + os.chdir(".././Face_Detection") + stage_2_input_dir = os.path.join(stage_1_output_dir, "restored_image") + stage_2_output_dir = os.path.join(opts.output_folder, "stage_2_detection_output") + if not os.path.exists(stage_2_output_dir): + os.makedirs(stage_2_output_dir) + if opts.HR: + stage_2_command = ( + "python detect_all_dlib_HR.py --url " + stage_2_input_dir + " --save_url " + stage_2_output_dir + ) + else: + stage_2_command = ( + "python detect_all_dlib.py --url " + stage_2_input_dir + " --save_url " + stage_2_output_dir + ) + run_cmd(stage_2_command) + print("Finish Stage 2 ...") + print("\n") + + ## Stage 3: Face Restore + print("Running Stage 3: Face Enhancement") + os.chdir(".././Face_Enhancement") + stage_3_input_mask = "./" + stage_3_input_face = stage_2_output_dir + stage_3_output_dir = os.path.join(opts.output_folder, "stage_3_face_output") + if not os.path.exists(stage_3_output_dir): + os.makedirs(stage_3_output_dir) + + if opts.HR: + opts.checkpoint_name='FaceSR_512' + stage_3_command = ( + "python test_face.py --old_face_folder " + + stage_3_input_face + + " --old_face_label_folder " + + stage_3_input_mask + + " --tensorboard_log --name " + + opts.checkpoint_name + + " --gpu_ids " + + gpu1 + + " --load_size 512 --label_nc 18 --no_instance --preprocess_mode resize --batchSize 1 --results_dir " + + stage_3_output_dir + + " --no_parsing_map" + ) + else: + stage_3_command = ( + "python test_face.py --old_face_folder " + + stage_3_input_face + + " --old_face_label_folder " + + stage_3_input_mask + + " --tensorboard_log --name " + + opts.checkpoint_name + + " --gpu_ids " + + gpu1 + + " --load_size 256 --label_nc 18 --no_instance --preprocess_mode resize --batchSize 4 --results_dir " + + stage_3_output_dir + + " --no_parsing_map" + ) + run_cmd(stage_3_command) + print("Finish Stage 3 ...") + print("\n") + + ## Stage 4: Warp back + print("Running Stage 4: Blending") + os.chdir(".././Face_Detection") + stage_4_input_image_dir = os.path.join(stage_1_output_dir, "restored_image") + stage_4_input_face_dir = os.path.join(stage_3_output_dir, "each_img") + stage_4_output_dir = os.path.join(opts.output_folder, "final_output") + if not os.path.exists(stage_4_output_dir): + os.makedirs(stage_4_output_dir) + if opts.HR: + stage_4_command = ( + "python align_warp_back_multiple_dlib_HR.py --origin_url " + + stage_4_input_image_dir + + " --replace_url " + + stage_4_input_face_dir + + " --save_url " + + stage_4_output_dir + ) + else: + stage_4_command = ( + "python align_warp_back_multiple_dlib.py --origin_url " + + stage_4_input_image_dir + + " --replace_url " + + stage_4_input_face_dir + + " --save_url " + + stage_4_output_dir + ) + run_cmd(stage_4_command) + print("Finish Stage 4 ...") + print("\n") + + print("All the processing is done. Please check the results.") + diff --git a/setup.sh b/setup.sh new file mode 100644 index 0000000000000000000000000000000000000000..9e523ad08019bea2b4839a4071d497cb46d85e12 --- /dev/null +++ b/setup.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +cd Face_Enhancement/models/networks/ +git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . +cd ../../../ + +cd Global/detection_models +git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . +cd ../../ + +cd Face_Detection/ +wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 +bzip2 -d shape_predictor_68_face_landmarks.dat.bz2 +cd ../ + +cd Face_Enhancement/ +wget https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life/releases/download/v1.0/face_checkpoints.zip +unzip face_checkpoints.zip +cd ../ + +cd Global/ +wget https://github.com/microsoft/Bringing-Old-Photos-Back-to-Life/releases/download/v1.0/global_checkpoints.zip +unzip global_checkpoints.zip +cd ../ \ No newline at end of file diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..38824291160deec62dafd5865fdbebc1824c3d3b --- /dev/null +++ b/utils.py @@ -0,0 +1,32 @@ +import cv2 +import os +import cv2 +import shutil +import sys +from subprocess import call + +def run_cmd(command): + try: + call(command, shell=True) + except KeyboardInterrupt: + print("Process interrupted") + sys.exit(1) + +def Restoration(image): + os.makedirs("Temp") + os.makedirs("Temp/input") + print(type(image)) + cv2.imwrite("Temp/input/input_img.png", image) + + command = ("python run.py --input_folder " + + "Temp/input" + + " --output_folder " + + "Temp" + + " --GPU " + + "-1" + + " --with_scratch") + run_cmd(command) + + result = cv2.imread("Temp/final_output/input_img.png") + shutil.rmtree("Temp") + return result \ No newline at end of file