Ahsen Khaliq committed on
Commit 2061b93
1 Parent(s): f912059

pytorch to onnx
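The conversion step itself is not part of this commit: the old app.py loaded the face_paint_512_v2_0.pt checkpoint into a PyTorch Generator at startup, while the new app.py downloads an already-exported face_paint_512_v2_0.onnx and runs it through onnxruntime. A minimal sketch of how such a checkpoint is typically exported with torch.onnx.export, assuming the Generator class from bryandlee/animegan2-pytorch; the dummy input shape, tensor names, and opset below are illustrative assumptions, not taken from the commit:

import torch
from model import Generator  # model definition from the animegan2-pytorch repo

model = Generator().eval()
model.load_state_dict(torch.load("face_paint_512_v2_0.pt", map_location="cpu"))

# Trace with a dummy NCHW input matching the 512x512 frames the demo feeds the model
dummy = torch.randn(1, 3, 512, 512)
torch.onnx.export(
    model,
    dummy,
    "face_paint_512_v2_0.onnx",
    input_names=["input"],    # assumed names; app.py reads the real ones via get_inputs()/get_outputs()
    output_names=["output"],
    opset_version=11,         # assumed opset
)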

Files changed (1)
  1. app.py +33 -45
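The heart of the diff below is the rewritten inference function: instead of calling the torch Generator, it turns the PIL input into an NCHW float32 array, runs one pass through the ONNX session, and maps the result back to an 8-bit RGB image. The same pipeline, condensed and commented for reference; the logic is copied from the diff, only the comments are added, and the redundant RGB-to-BGR-to-RGB round-trip at the end is dropped:

import cv2 as cv
import numpy as np
from PIL import Image
import onnxruntime

onnx_session = onnxruntime.InferenceSession("face_paint_512_v2_0.onnx")
input_name = onnx_session.get_inputs()[0].name    # query tensor names instead of hard-coding them
output_name = onnx_session.get_outputs()[0].name
side_length = 512

def inference(img):
    # PIL gives RGB; flip to BGR so the OpenCV calls behave as expected
    image = np.array(img)[:, :, ::-1].copy()
    # resize to the fixed 512x512 the exported graph expects, then back to RGB
    image = cv.resize(image, dsize=(side_length, side_length))
    x = cv.cvtColor(image, cv.COLOR_BGR2RGB)
    # HWC uint8 -> NCHW float32, scaled with the same x * 2 - 1 as the commit
    x = np.array(x, dtype=np.float32).transpose(2, 0, 1)
    x = (x * 2 - 1).reshape(-1, 3, side_length, side_length)
    # single forward pass through onnxruntime
    onnx_result = onnx_session.run([output_name], {input_name: x})
    # drop the batch dim, map [-1, 1] -> [0, 255], and return HWC uint8 as PIL
    y = np.array(onnx_result).squeeze()
    y = ((y * 0.5 + 0.5).clip(0, 1) * 255).transpose(1, 2, 0).astype('uint8')
    return Image.fromarray(y)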
app.py CHANGED
@@ -1,63 +1,51 @@
+import onnxruntime
+print(onnxruntime.get_device())
 import os
-os.system("git clone https://github.com/bryandlee/animegan2-pytorch")
-os.system("gdown https://drive.google.com/uc?id=18H3iK09_d54qEDoWIc82SyWB2xun4gjU")
-import sys
-sys.path.append("animegan2-pytorch")
-
-import torch
-torch.set_grad_enabled(False)
-
-from model import Generator
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-model = Generator().eval().to(device)
-model.load_state_dict(torch.load("face_paint_512_v2_0.pt"))
-
-from PIL import Image
-from torchvision.transforms.functional import to_tensor, to_pil_image
 import gradio as gr
-
-def face2paint(
-    img: Image.Image,
-    size: int,
-    side_by_side: bool = False,
-) -> Image.Image:
+os.system("pip install gdown")
+os.system("gdown https://drive.google.com/uc?id=1riNxV1BWMAXfmWZ3LrQbEkvzV8f7lOCp")
 
 
-    input = to_tensor(img).unsqueeze(0) * 2 - 1
-    output = model(input.to(device)).cpu()[0]
+onnx_session = onnxruntime.InferenceSession("face_paint_512_v2_0.onnx")
 
-    if side_by_side:
-        output = torch.cat([input[0], output], dim=2)
+input_name = onnx_session.get_inputs()[0].name
+output_name = onnx_session.get_outputs()[0].name
 
-    output = (output * 0.5 + 0.5).clip(0, 1)
-
-    return to_pil_image(output)
-
+side_length = 512
 
-import collections
-from typing import Union, List
+import cv2 as cv
 import numpy as np
 from PIL import Image
 
-
-import PIL.Image
-import PIL.ImageFile
-import numpy as np
-import scipy.ndimage
+def inference(img):
+    image = np.array(img)
+    image = image[:, :, ::-1].copy()
+    image = cv.resize(image, dsize=(side_length, side_length))
+    x = cv.cvtColor(image, cv.COLOR_BGR2RGB)
+
+    x = np.array(x, dtype=np.float32)
+    x = x.transpose(2, 0, 1)
+    x = x * 2 - 1
+    x = x.reshape(-1, 3, side_length, side_length)
+
+    onnx_result = onnx_session.run([output_name], {input_name: x})
+
+    onnx_result = np.array(onnx_result).squeeze()
+    onnx_result = (onnx_result * 0.5 + 0.5).clip(0, 1)
+    onnx_result = onnx_result * 255
+
+    onnx_result = onnx_result.transpose(1, 2, 0).astype('uint8')
+    onnx_result = cv.cvtColor(onnx_result, cv.COLOR_RGB2BGR)
 
 
-import requests
-
-def inference(img):
-    out = face2paint(img, 512)
-    return out
+    img = cv.cvtColor(onnx_result, cv.COLOR_BGR2RGB)
+    im_pil = Image.fromarray(img)
+    return im_pil
 
 
 title = "Animeganv2"
 description = "Gradio demo for AnimeGanv2 Face Portrait v2. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Please use a cropped portrait picture for best results similar to the examples below"
-article = "<p style='text-align: center'><a href='https://github.com/bryandlee/animegan2-pytorch' target='_blank'>Github Repo</a></p><p style='text-align: center'>samples from repo: <img src='https://user-images.githubusercontent.com/26464535/129888683-98bb6283-7bb8-4d1a-a04a-e795f5858dcf.gif' alt='animation'/> <img src='https://user-images.githubusercontent.com/26464535/137619176-59620b59-4e20-4d98-9559-a424f86b7f24.jpg' alt='animation'/></p>"
+article = "<p style='text-align: center'><a href='https://github.com/bryandlee/animegan2-pytorch' target='_blank'>Github Repo Pytorch</a> | <a href='https://github.com/Kazuhito00/AnimeGANv2-ONNX-Sample' target='_blank'>Github Repo ONNX</a></p><p style='text-align: center'>samples from repo: <img src='https://user-images.githubusercontent.com/26464535/129888683-98bb6283-7bb8-4d1a-a04a-e795f5858dcf.gif' alt='animation'/> <img src='https://user-images.githubusercontent.com/26464535/137619176-59620b59-4e20-4d98-9559-a424f86b7f24.jpg' alt='animation'/></p>"
 
 examples=[['groot.jpeg'],['bill.png'],['tony.png'],['elon.png'],['IU.png'],['billie.png'],['will.png'],['beyonce.jpeg'],['gongyoo.jpeg']]
-gr.Interface(inference, gr.inputs.Image(type="pil",shape=(512,512)), gr.outputs.Image(type="pil"),title=title,description=description,article=article,examples=examples,enable_queue=True).launch()
+gr.Interface(inference, gr.inputs.Image(type="pil"), gr.outputs.Image(type="pil"),title=title,description=description,article=article,enable_queue=True,examples=examples).launch()
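
One detail worth verifying in a follow-up: the two paths normalize differently. The removed PyTorch code used to_tensor(img), which scales 8-bit pixels to [0, 1] before the * 2 - 1, so the Generator saw values in [-1, 1]; the committed ONNX path applies x * 2 - 1 to float32 pixels still in [0, 255], which reaches the graph as [-1, 509] unless the exported model rescales internally. A short arithmetic check of what each path actually produces:

import numpy as np

u8 = np.array([0, 128, 255], dtype=np.float32)
print((u8 / 255.0) * 2 - 1)  # old path via to_tensor: [-1.  0.0039...  1.]
print(u8 * 2 - 1)            # new path as committed:  [-1.  255.  509.]

If the face_paint_512_v2_0.onnx graph expects the same [-1, 1] range as the original checkpoint, an x = x / 255 before the scaling would restore parity; whether the exported graph bakes that scaling in can be checked by inspecting the model or comparing outputs against the PyTorch version.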