NDugar commited on
Commit
6ed6d81
1 Parent(s): d68506b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from cgitb import enable
2
+ from ctypes.wintypes import HFONT
3
+ import os
4
+ import sys
5
+ import torch
6
+ import gradio as gr
7
+ import numpy as np
8
+ import torchvision.transforms as transforms
9
+
10
+
11
+ from torch.autograd import Variable
12
+ from network.Transformer import Transformer
13
+ from huggingface_hub import hf_hub_download
14
+
15
+ from PIL import Image
16
+
17
+ import logging
18
+
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # Constants
23
+
24
+ MODEL_PATH = "models"
25
+ COLOUR_MODEL = "RGB"
26
+
27
+ MODEL_REPO = "NDugar/horse_to_zebra_cycle_GAN"
28
+ MODEL_FILE = "h2z-85epoch.pth"
29
+
30
+ # Model Initalisation
31
+ #shinkai_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_SHINKAI, filename=MODEL_FILE_SHINKAI)
32
+ #hosoda_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_HOSODA, filename=MODEL_FILE_HOSODA)
33
+ #miyazaki_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_MIYAZAKI, filename=MODEL_FILE_MIYAZAKI)
34
+ model_hfhub = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
35
+
36
+ #shinkai_model = Transformer()
37
+ #hosoda_model = Transformer()
38
+ #miyazaki_model = Transformer()
39
+ model = Transformer()
40
+
41
+ enable_gpu = torch.cuda.is_available()
42
+ map_location = torch.device("cuda") if enable_gpu else "cpu"
43
+
44
+ model.load_state_dict(torch.load(model_hfhub, map_location=map_location))
45
+
46
+ shinkai_model.eval()
47
+ hosoda_model.eval()
48
+ miyazaki_model.eval()
49
+ kon_model.eval()
50
+
51
+
52
+ # Functions
53
+
54
+ def get_model():
55
+ return model
56
+
57
+
58
+ def adjust_image_for_model(img):
59
+ logger.info(f"Image Height: {img.height}, Image Width: {img.width}")
60
+ return img
61
+
62
+
63
+ def inference(img, style):
64
+ img = adjust_image_for_model(img)
65
+ input_image = img.convert(COLOUR_MODEL)
66
+ input_image = np.asarray(input_image)
67
+ input_image = input_image[:, :, [2, 1, 0]]
68
+ input_image = transforms.ToTensor()(input_image).unsqueeze(0)
69
+ input_image = -1 + 2 * input_image
70
+
71
+ if enable_gpu:
72
+ logger.info(f"CUDA found. Using GPU.")
73
+ input_image = Variable(input_image).cuda()
74
+ else:
75
+ logger.info(f"CUDA not found. Using CPU.")
76
+ input_image = Variable(input_image).float()
77
+ model = get_model()
78
+ output_image = model(input_image)
79
+ output_image = output_image[0]
80
+ # BGR -> RGB
81
+ output_image = output_image[[2, 1, 0], :, :]
82
+ output_image = output_image.data.cpu().float() * 0.5 + 0.5
83
+
84
+ return transforms.ToPILImage()(output_image)
85
+
86
+
87
+ # Gradio setup
88
+
89
+ title = "Horse 2 Zebra GAN"
90
+ description = "Gradio Demo for CycleGAN"
91
+
92
+ gr.Interface(
93
+ fn=inference,
94
+ inputs=[
95
+ gr.inputs.Image(
96
+ type="pil",
97
+ label="Input Photo",
98
+ ),
99
+ ],
100
+ outputs=gr.outputs.Image(
101
+ type="pil",
102
+ label="Output Image",
103
+ ),
104
+ title=title,
105
+ description=description,
106
+ article=article,
107
+ examples=examples,
108
+ allow_flagging="never",
109
+ allow_screenshot=False,
110
+ ).launch(enable_queue=True)