Doron Adler committed
Commit 27b9d2f
1 Parent(s): 843e13d

Dragness - Pixel2Style2Pixel based face2drag

.gitattributes CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+shape_predictor_5_face_landmarks.dat filter=lfs diff=lfs merge=lfs -text
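The new rule tracks the dlib landmark model with Git LFS, so only a small pointer file is committed (its contents appear further down). A hedged sketch, not part of the commit, for checking that a clone actually materialized the LFS payload rather than the pointer:

def is_lfs_pointer(path: str) -> bool:
    # A checkout made without `git lfs install` leaves the pointer text
    # (version/oid/size lines) in place of the real binary payload.
    with open(path, "rb") as f:
        head = f.read(64)
    return head.startswith(b"version https://git-lfs.github.com/spec/v1")

print(is_lfs_pointer("shape_predictor_5_face_landmarks.dat"))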
Example00001.jpg ADDED
Binary file
Example00002.jpg ADDED
Binary file
Example00003.jpg ADDED
Binary file
Example00004.jpg ADDED
Binary file
Example00005.jpg ADDED
Binary file
Example00006.jpg ADDED
Binary file
Example00007.jpg ADDED
Binary file
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
 title: Dragness
-emoji: 🐢
-colorFrom: green
-colorTo: purple
+emoji: 👸
+colorFrom: yellow
+colorTo: blue
 sdk: gradio
 app_file: app.py
 pinned: false
Sample00001.jpg ADDED
Binary file
Sample00002.jpg ADDED
Binary file
Sample00003.jpg ADDED
Binary file
Sample00004.jpg ADDED
Binary file
Sample00005.jpg ADDED
Binary file
Sample00006.jpg ADDED
Binary file
app.py ADDED
@@ -0,0 +1,69 @@
+import os
+#os.system("gdown https://drive.google.com/uc?id=1WEST2O6svlQWpJNomX3947Q2bfJz4bAJ")
+#os.system("gdown https://drive.google.com/uc?id=1CbnhlUI9Tms2o7S2eCg9qwGXZFCyROYy")
+import face_detection
+import numpy as np
+import torch
+import gradio as gr
+from PIL import Image
+
+torch.set_grad_enabled(False)
+net = torch.jit.load('dragness_p2s2p_torchscript_cpu.pt')
+net.eval()
+
+
+def tensor2im(var):
+    """Convert a CHW tensor in [-1, 1] to a PIL image in [0, 255]."""
+    var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy()
+    var = (var + 1) / 2
+    var[var < 0] = 0
+    var[var > 1] = 1
+    var = var * 255
+    return Image.fromarray(var.astype('uint8'))
+
+
+def image_as_array(image_in):
+    """Convert a PIL image to a 1xCxHxW float32 array in [-1, 1]."""
+    im_array = np.array(image_in, np.float32)
+    im_array = (im_array / 255) * 2 - 1
+    im_array = np.transpose(im_array, (2, 0, 1))
+    im_array = np.expand_dims(im_array, 0)
+    return im_array
+
+
+def find_aligned_face(image_in, size=256):
+    aligned_image, n_faces, quad = face_detection.align(image_in, face_index=0, output_size=size)
+    return aligned_image, n_faces, quad
+
+
+def align_first_face(image_in, size=256):
+    """Align the first detected face; fall back to a plain resize if none is found."""
+    aligned_image, n_faces, quad = find_aligned_face(image_in, size=size)
+    if n_faces == 0:
+        image_in = image_in.resize((size, size))
+        im_array = image_as_array(image_in)
+    else:
+        im_array = image_as_array(aligned_image)
+    return im_array
+
+
+def face2drag(
+    img: Image.Image,
+    size: int
+) -> Image.Image:
+    aligned_img = align_first_face(img, size=size)
+    input_tensor = torch.Tensor(aligned_img)
+    output = net(input_tensor)
+    return tensor2im(output[0])
+
+
+def inference(img):
+    return face2drag(img, 256)
+
+
+title = "Dragness"
+description = "Gradio demo for a drag-finetuned Pixel2Style2Pixel. To use it, simply upload your image, or click one of the examples to load it. Read more at the links below."
+article = "<p style='text-align: center'><a href='https://github.com/justinpinkney/pixel2style2pixel/tree/nw' target='_blank'>Github Repo</a></p><p style='text-align: center'>samples: <img src='Sample00001.jpg' alt='Sample00001'/><img src='Sample00002.jpg' alt='Sample00002'/><img src='Sample00003.jpg' alt='Sample00003'/><img src='Sample00004.jpg' alt='Sample00004'/><img src='Sample00005.jpg' alt='Sample00005'/><img src='Sample00006.jpg' alt='Sample00006'/></p><p>The drag model was fine-tuned by Doron Adler</p>"
+
+examples = [['Example00001.jpg'], ['Example00002.jpg'], ['Example00003.jpg'], ['Example00004.jpg'], ['Example00005.jpg'], ['Example00006.jpg'], ['Example00007.jpg']]
+gr.Interface(inference, gr.inputs.Image(type="pil", shape=(256, 256)), gr.outputs.Image(type="pil"), title=title, description=description, article=article, examples=examples, enable_queue=True).launch()
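For reference, image_as_array and tensor2im above are near-inverses: pixels are mapped from [0, 255] into the [-1, 1] range the TorchScript model expects, then back on output. A standalone sketch of that round trip (illustrative, not part of the commit):

import numpy as np

# Forward scaling used by image_as_array: [0, 255] uint8 -> [-1, 1] float32.
rgb = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
arr = (rgb.astype(np.float32) / 255) * 2 - 1

# Inverse scaling used by tensor2im: clip to [0, 1], then back to [0, 255].
back = (np.clip((arr + 1) / 2, 0, 1) * 255).astype('uint8')

# The round trip is exact up to uint8 truncation (at most one grey level off).
assert np.abs(back.astype(int) - rgb.astype(int)).max() <= 1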
dragness_p2s2p_torchscript_cpu.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c758ffe5265d2041a71f369d3eb3565f187aa6bd39f647e1124a66cd79a26f3c
+size 1202678391
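Only this Git LFS pointer (spec version, sha256 oid, byte size) is committed; the ~1.2 GB TorchScript model itself lives in LFS storage. A sketch, with the path and expected values taken from the pointer above, of how a downloaded copy could be verified:

import hashlib

EXPECTED_OID = "c758ffe5265d2041a71f369d3eb3565f187aa6bd39f647e1124a66cd79a26f3c"
EXPECTED_SIZE = 1202678391

h = hashlib.sha256()
n_bytes = 0
with open("dragness_p2s2p_torchscript_cpu.pt", "rb") as f:
    # Hash in 1 MiB chunks so the large file never sits in memory at once.
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)
        n_bytes += len(chunk)

assert n_bytes == EXPECTED_SIZE, f"size mismatch: {n_bytes}"
assert h.hexdigest() == EXPECTED_OID, "sha256 mismatch"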
face_detection.py ADDED
@@ -0,0 +1,140 @@
+# Copyright (c) 2021 Justin Pinkney
+
+import dlib
+import numpy as np
+import os
+from PIL import Image
+from PIL import ImageOps
+from scipy.ndimage import gaussian_filter
+import cv2
+
+
+MODEL_PATH = "shape_predictor_5_face_landmarks.dat"
+detector = dlib.get_frontal_face_detector()
+
+
+def align(image_in, face_index=0, output_size=256):
+    landmarks = list(get_landmarks(image_in))
+    n_faces = len(landmarks)
+    face_index = min(n_faces-1, face_index)
+    if n_faces == 0:
+        aligned_image = image_in
+        quad = None
+    else:
+        aligned_image, quad = image_align(image_in, landmarks[face_index], output_size=output_size)
+
+    return aligned_image, n_faces, quad
+
+
+def composite_images(quad, img, output):
+    """Composite an image into an output canvas according to transformed co-ords"""
+    output = output.convert("RGBA")
+    img = img.convert("RGBA")
+    input_size = img.size
+    src = np.array(((0, 0), (0, input_size[1]), input_size, (input_size[0], 0)), dtype=np.float32)
+    dst = np.float32(quad)
+    mtx = cv2.getPerspectiveTransform(dst, src)
+    img = img.transform(output.size, Image.PERSPECTIVE, mtx.flatten(), Image.BILINEAR)
+    output.alpha_composite(img)
+
+    return output.convert("RGB")
+
+
+def get_landmarks(image):
+    """Get landmarks from PIL image"""
+    shape_predictor = dlib.shape_predictor(MODEL_PATH)
+
+    max_size = max(image.size)
+    reduction_scale = int(max_size/512)
+    if reduction_scale == 0:
+        reduction_scale = 1
+    downscaled = image.reduce(reduction_scale)
+    img = np.array(downscaled)
+    detections = detector(img, 0)
+
+    for detection in detections:
+        try:
+            face_landmarks = [(reduction_scale*item.x, reduction_scale*item.y) for item in shape_predictor(img, detection).parts()]
+            yield face_landmarks
+        except Exception as e:
+            print(e)
+
+
+def image_align(src_img, face_landmarks, output_size=512, transform_size=2048, enable_padding=True, x_scale=1, y_scale=1, em_scale=0.1, alpha=False):
+    # Align function modified from ffhq-dataset
+    # See https://github.com/NVlabs/ffhq-dataset for license
+
+    lm = np.array(face_landmarks)
+    lm_eye_left = lm[2:3]  # left-clockwise
+    lm_eye_right = lm[0:1]  # left-clockwise
+
+    # Calculate auxiliary vectors.
+    eye_left = np.mean(lm_eye_left, axis=0)
+    eye_right = np.mean(lm_eye_right, axis=0)
+    eye_avg = (eye_left + eye_right) * 0.5
+    eye_to_eye = 0.71*(eye_right - eye_left)
+    mouth_avg = lm[4]
+    eye_to_mouth = 1.35*(mouth_avg - eye_avg)
+
+    # Choose oriented crop rectangle.
+    x = eye_to_eye.copy()
+    x /= np.hypot(*x)
+    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
+    x *= x_scale
+    y = np.flipud(x) * [-y_scale, y_scale]
+    c = eye_avg + eye_to_mouth * em_scale
+    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+    quad_orig = quad.copy()
+    qsize = np.hypot(*x) * 2
+
+    try:
+        src_img = ImageOps.exif_transpose(src_img)
+    except Exception:
+        print("exif problem, not rotating")
+
+    img = src_img.convert('RGBA').convert('RGB')
+
+    # Shrink.
+    shrink = int(np.floor(qsize / output_size * 0.5))
+    if shrink > 1:
+        rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
+        img = img.resize(rsize, Image.ANTIALIAS)
+        quad /= shrink
+        qsize /= shrink
+
+    # Crop.
+    border = max(int(np.rint(qsize * 0.1)), 3)
+    crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
+    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
+    if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+        img = img.crop(crop)
+        quad -= crop[0:2]
+
+    # Pad.
+    pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
+    pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
+    if enable_padding and max(pad) > border - 4:
+        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+        img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+        h, w, _ = img.shape
+        y, x, _ = np.ogrid[:h, :w, :1]
+        mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3]))
+        blur = qsize * 0.02
+        img += (gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+        img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
+        img = np.uint8(np.clip(np.rint(img), 0, 255))
+        if alpha:
+            mask = 1-np.clip(3.0 * mask, 0.0, 1.0)
+            mask = np.uint8(np.clip(np.rint(mask*255), 0, 255))
+            img = np.concatenate((img, mask), axis=2)
+            img = Image.fromarray(img, 'RGBA')
+        else:
+            img = Image.fromarray(img, 'RGB')
+        quad += pad[:2]
+
+    # Transform.
+    img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR)
+    if output_size < transform_size:
+        img = img.resize((output_size, output_size), Image.ANTIALIAS)
+
+    return img, quad_orig
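The quad returned by align is what makes the crop reversible: composite_images warps an edited 256×256 face back onto the original canvas with a perspective transform. A sketch of that round trip (file names illustrative, not part of the commit):

from PIL import Image
import face_detection

# Align the first face, apply an (identity) edit, and paste it back.
original = Image.open("Example00001.jpg").convert("RGB")
aligned, n_faces, quad = face_detection.align(original, face_index=0, output_size=256)

if n_faces > 0:
    edited = aligned  # a real pipeline would run the drag model here
    restored = face_detection.composite_images(quad, edited, original)
    restored.save("restored.jpg")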
requirements.txt ADDED
@@ -0,0 +1,11 @@
+numpy
+opencv-python
+Pillow
+scikit-image
+torch
+torchvision
+ninja
+dlib
+gdown
+scipy
+cmake
shape_predictor_5_face_landmarks.dat ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4b1e9804792707d3a405c2c16a80a20269e6675021f64a41d30fffafbc41888
+size 9150489