Ebost commited on
Commit
d1f35db
1 Parent(s): a4413db

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.system("git clone https://github.com/bryandlee/animegan2-pytorch")
3
+ os.system("gdown https://drive.google.com/uc?id=1WK5Mdt6mwlcsqCZMHkCUSDJxN1UyFi0-")
4
+ os.system("gdown https://drive.google.com/uc?id=18H3iK09_d54qEDoWIc82SyWB2xun4gjU")
5
+ import sys
6
+ sys.path.append("animegan2-pytorch")
7
+
8
+ import torch
9
+ torch.set_grad_enabled(False)
10
+
11
+ from model import Generator
12
+
13
+ device = "cpu"
14
+
15
+ model = Generator().eval().to(device)
16
+ model.load_state_dict(torch.load("face_paint_512_v2_0.pt"))
17
+
18
+ from PIL import Image
19
+ from torchvision.transforms.functional import to_tensor, to_pil_image
20
+ import gradio as gr
21
+
22
+ def face2paint(
23
+ img: Image.Image,
24
+ size: int,
25
+ side_by_side: bool = False,
26
+ ) -> Image.Image:
27
+
28
+
29
+ input = to_tensor(img).unsqueeze(0) * 2 - 1
30
+ output = model(input.to(device)).cpu()[0]
31
+
32
+ if side_by_side:
33
+ output = torch.cat([input[0], output], dim=2)
34
+
35
+ output = (output * 0.5 + 0.5).clip(0, 1)
36
+
37
+ return to_pil_image(output)
38
+
39
+
40
+
41
+
42
+ import os
43
+ import collections
44
+ from typing import Union, List
45
+ import numpy as np
46
+ from PIL import Image
47
+
48
+
49
+ import PIL.Image
50
+ import PIL.ImageFile
51
+ import numpy as np
52
+ import scipy.ndimage
53
+
54
+
55
+ import requests
56
+
57
+ def inference(img):
58
+ out = face2paint(img, 512)
59
+ return out
60
+
61
+
62
+ title = "Animeganv2"
63
+ description = "Gradio demo for AnimeGanv2 Face Portrait v2. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
64
+ article = "<p style='text-align: center'><a href='https://github.com/bryandlee/animegan2-pytorch' target='_blank'>Github Repo</a></p>"
65
+
66
+ examples=[['groot.jpeg']]
67
+ gr.Interface(inference, gr.inputs.Image(type="pil",shape=(512,512)), gr.outputs.Image(type="pil"),title=title,description=description,article=article,examples=examples,enable_queue=True).launch()