Egrt committed
Commit 2f91d83
1 Parent(s): bf2935c

Upload app.py

Files changed (1)
  1. app.py +144 -0
app.py ADDED
@@ -0,0 +1,144 @@
+ import pystuck
+ pystuck.run_server()
+ import os
+ # pin Gradio and fetch the three TorchScript checkpoints at startup
+ os.system("pip install gradio==2.5.3")
+ os.system("wget https://github.com/Sxela/ArcaneGAN/releases/download/v0.4/ArcaneGANv0.4.jit")
+ os.system("wget https://github.com/Sxela/ArcaneGAN/releases/download/v0.3/ArcaneGANv0.3.jit")
+ os.system("wget https://github.com/Sxela/ArcaneGAN/releases/download/v0.2/ArcaneGANv0.2.jit")
+ os.system("pip -qq install facenet_pytorch")
+ from facenet_pytorch import MTCNN
+ from torchvision import transforms
+ import torch, PIL
+ # download the example portrait used by the demo
+ torch.hub.download_url_to_file('https://hf.space/gradioiframe/akhaliq/AnimeGANv2/file/bill.png', 'bill.png')
+ import gradio as gr
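+ # gradio is pinned to 2.5.3 because the interface at the bottom of this file uses the
+ # old gr.inputs / gr.outputs namespaces, which later Gradio releases deprecated and removed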
+ # simplest ye olde trustworthy MTCNN for face detection with landmarks
+ mtcnn = MTCNN(image_size=256, margin=80)
+
+ def detect(img):
+     # detect faces
+     batch_boxes, batch_probs, batch_points = mtcnn.detect(img, landmarks=True)
+     # select faces
+     if not mtcnn.keep_all:
+         batch_boxes, batch_probs, batch_points = mtcnn.select_boxes(
+             batch_boxes, batch_probs, batch_points, img, method=mtcnn.selection_method
+         )
+     return batch_boxes, batch_points
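+ # detect() passes MTCNN's output straight through: boxes are (n_faces, 4) arrays of
+ # [x1, y1, x2, y2] pixel coordinates, or None when no face is found, which is exactly
+ # what scale() below checks for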
+ # my version of isOdd, should make a separate repo for it :D
+ def makeEven(_x):
+     return int(_x) if (_x % 2 == 0) else int(_x + 1)
+ # the actual scaler function
+ def scale(boxes, _img, max_res=1_500_000, target_face=256, fixed_ratio=0, max_upscale=2, VERBOSE=False):
+     """
+     A useful scaler algorithm, based on face detection.
+     Takes a PIL.Image, returns a uniformly scaled PIL.Image.
+     boxes: a list of detected bboxes
+     _img: PIL.Image
+     max_res: maximum pixel area to fit into. Use it to stay below the VRAM limits of your GPU.
+     target_face: desired face size. Upscales or downscales the whole image to fit the detected face into that dimension.
+     fixed_ratio: fixed scale. Ignores the face size, but doesn't ignore the max_res limit.
+     max_upscale: maximum upscale ratio. Prevents images with tiny faces from being upscaled into a blurry mess.
+     """
+     x, y = _img.size
+
+     ratio = 2  # initial ratio
+
+     # scale to the desired face size
+     if boxes is not None:
+         if len(boxes) > 0:
+             ratio = target_face / max(boxes[0][2:] - boxes[0][:2])
+             ratio = min(ratio, max_upscale)
+             if VERBOSE: print('up by', ratio)
+     if fixed_ratio > 0:
+         if VERBOSE: print('fixed ratio')
+         ratio = fixed_ratio
+
+     x *= ratio
+     y *= ratio
+
+     # downscale to fit into max_res
+     res = x * y
+     if res > max_res:
+         ratio = pow(res / max_res, 1 / 2)
+         if VERBOSE: print(ratio)
+         x = int(x / ratio)
+         y = int(y / ratio)
+
+     # make dimensions even, because NNs usually fail on odd dimensions due to skip-connection size mismatches
+     x = makeEven(int(x))
+     y = makeEven(int(y))
+
+     size = (x, y)
+     return _img.resize(size)
+ def scale_by_face_size(_img, max_res=1_500_000, target_face=256, fix_ratio=0, max_upscale=2, VERBOSE=False):
+     boxes, _ = detect(_img)
+     if VERBOSE: print('boxes', boxes)
+     img_resized = scale(boxes, _img, max_res, target_face, fix_ratio, max_upscale, VERBOSE)
+     return img_resized
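+ # hypothetical usage sketch (the standalone call is only an illustration; the demo
+ # itself invokes scale_by_face_size from process() further down):
+ #   img = PIL.Image.open('bill.png').convert('RGB')
+ #   img = scale_by_face_size(img, target_face=300, VERBOSE=True)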
+ size = 256
+ means = [0.485, 0.456, 0.406]
+ stds = [0.229, 0.224, 0.225]
+ # the ImageNet statistics above, reshaped to (3, 1, 1) so they broadcast over CHW tensors
+ t_stds = torch.tensor(stds).cuda().half()[:, None, None]
+ t_means = torch.tensor(means).cuda().half()[:, None, None]
+
+ img_transforms = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Normalize(means, stds)])
+
+ def tensor2im(var):
+     # undo Normalize, then map to 0-255 and HWC channel order
+     return var.mul(t_stds).add(t_means).mul(255.).clamp(0, 255).permute(1, 2, 0)
+ def proc_pil_img(input_image, model):
+     transformed_image = img_transforms(input_image)[None, ...].cuda().half()
+
+     with torch.no_grad():
+         result_image = model(transformed_image)[0]
+         print(result_image.shape)
+         output_image = tensor2im(result_image)
+         output_image = output_image.detach().cpu().numpy().astype('uint8')
+         output_image = PIL.Image.fromarray(output_image)
+     return output_image
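+ # shape walkthrough: img_transforms yields (3, H, W); [None, ...] adds the batch
+ # dimension the model expects, [0] drops it again, and tensor2im's permute produces
+ # the (H, W, 3) uint8 layout that PIL.Image.fromarray requires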
+
+ # fit() caps the longest side at maxsize; it is defined here but never called below
+ def fit(img, maxsize=512):
+     maxdim = max(*img.size)
+     if maxdim > maxsize:
+         ratio = maxsize / maxdim
+         x, y = img.size
+         size = (int(x * ratio), int(y * ratio))
+         img = img.resize(size)
+     return img
+
+ modelv4 = torch.jit.load('./ArcaneGANv0.4.jit').eval().cuda().half()
+ modelv3 = torch.jit.load('./ArcaneGANv0.3.jit').eval().cuda().half()
+ modelv2 = torch.jit.load('./ArcaneGANv0.2.jit').eval().cuda().half()
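+ # the checkpoints are loaded in FP16 on the GPU, so the script assumes a CUDA device;
+ # on a CPU-only host every .cuda() call above would raise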
+
+ def process(im, version):
+     # every version uses the same face-aware pre-scaling; only the checkpoint differs
+     im = scale_by_face_size(im, target_face=300, max_res=1_500_000, max_upscale=2)
+     if version == 'version 0.4':
+         res = proc_pil_img(im, modelv4)
+     elif version == 'version 0.3':
+         res = proc_pil_img(im, modelv3)
+     else:
+         res = proc_pil_img(im, modelv2)
+     return res
+
+ title = "ArcaneGAN"
+ description = "Gradio demo for ArcaneGAN, which translates portraits into the Arcane style. To use it, simply upload your image or click the example below to load it. Read more at the links below."
+ article = "<div style='text-align: center;'>ArcaneGAN by <a href='https://twitter.com/devdef' target='_blank'>Alexander S</a> | <a href='https://github.com/Sxela/ArcaneGAN' target='_blank'>Github Repo</a> | <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_arcanegan' alt='visitor badge'></center></div>"
+ gr.Interface(
+     process,
+     [gr.inputs.Image(type="pil", label="Input", shape=(256, 256)),
+      gr.inputs.Radio(choices=['version 0.2', 'version 0.3', 'version 0.4'], type="value", default='version 0.4', label='version')],
+     gr.outputs.Image(type="pil", label="Output"),
+     title=title,
+     description=description,
+     article=article,
+     # only bill.png is downloaded above, so it is the only example that resolves to a real file
+     examples=[['bill.png', 'version 0.3']],
+     enable_queue=True
+ ).launch(debug=True)