akiyamasho commited on
Commit
76b1411
1 Parent(s): 62b0925

MAINT: use model repos

Browse files
Files changed (1) hide show
  1. app.py +33 -6
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
  import sys
3
  import torch
@@ -8,6 +10,7 @@ import torchvision.transforms as transforms
8
 
9
  from torch.autograd import Variable
10
  from network.Transformer import Transformer
 
11
 
12
  from PIL import Image
13
 
@@ -16,6 +19,8 @@ import logging
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
 
 
19
  MAX_DIMENSION = 1280
20
  MODEL_PATH = "models"
21
  COLOUR_MODEL = "RGB"
@@ -27,23 +32,43 @@ STYLE_KON = "Satoshi Kon"
27
  DEFAULT_STYLE = STYLE_SHINKAI
28
  STYLE_CHOICE_LIST = [STYLE_SHINKAI, STYLE_HOSODA, STYLE_MIYAZAKI, STYLE_KON]
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  shinkai_model = Transformer()
31
  hosoda_model = Transformer()
32
  miyazaki_model = Transformer()
33
  kon_model = Transformer()
34
 
 
 
35
 
36
  shinkai_model.load_state_dict(
37
- torch.load(os.path.join(MODEL_PATH, "shinkai_makoto.pth"))
38
  )
39
  hosoda_model.load_state_dict(
40
- torch.load(os.path.join(MODEL_PATH, "hosoda_mamoru.pth"))
41
  )
42
  miyazaki_model.load_state_dict(
43
- torch.load(os.path.join(MODEL_PATH, "miyazaki_hayao.pth"))
44
  )
45
  kon_model.load_state_dict(
46
- torch.load(os.path.join(MODEL_PATH, "kon_satoshi.pth"))
47
  )
48
 
49
  shinkai_model.eval()
@@ -51,8 +76,8 @@ hosoda_model.eval()
51
  miyazaki_model.eval()
52
  kon_model.eval()
53
 
54
- enable_gpu = torch.cuda.is_available()
55
 
 
56
 
57
  def get_model(style):
58
  if style == STYLE_SHINKAI:
@@ -109,9 +134,11 @@ def inference(img, style):
109
  return transforms.ToPILImage()(output_image)
110
 
111
 
 
 
112
  title = "Anime Background GAN"
113
  description = "Gradio Demo for CartoonGAN by Chen Et. Al. Models are Shinkai Makoto, Hosoda Mamoru, Kon Satoshi, and Miyazaki Hayao."
114
- article = "<p style='text-align: center'><a href='http://openaccess.thecvf.com/content_cvpr_2018/CameraReady/2205.pdf' target='_blank'>CartoonGAN Whitepaper from Chen et.al</a></p><p style='text-align: center'><a href='https://github.com/venture-anime/anime-background-gan-hf-space' target='_blank'>Spaces Github Repo</a></p><p style='text-align: center'><a href='https://github.com/venture-anime/cartoongan-pytorch' target='_blank'>Local Run GitHub Repo</a></p><p style='text-align: center'><a href='https://github.com/Yijunmaverick/CartoonGAN-Test-Pytorch-Torch' target='_blank'>Original Implementation from Yijunmaverick</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=akiyamasho' alt='visitor badge'></center></p>"
115
 
116
  examples = [
117
  ["examples/garden_in.jpg", STYLE_SHINKAI],
1
+ from cgitb import enable
2
+ from ctypes.wintypes import HFONT
3
  import os
4
  import sys
5
  import torch
10
 
11
  from torch.autograd import Variable
12
  from network.Transformer import Transformer
13
+ from huggingface_hub import hf_hub_download
14
 
15
  from PIL import Image
16
 
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
22
+ # Constants
23
+
24
  MAX_DIMENSION = 1280
25
  MODEL_PATH = "models"
26
  COLOUR_MODEL = "RGB"
32
  DEFAULT_STYLE = STYLE_SHINKAI
33
  STYLE_CHOICE_LIST = [STYLE_SHINKAI, STYLE_HOSODA, STYLE_MIYAZAKI, STYLE_KON]
34
 
35
+ MODEL_REPO_SHINKAI = "akiyamasho/AnimeBackgroundGAN-Shinkai"
36
+ MODEL_FILE_SHINKAI = "shinkai_makoto.pth"
37
+
38
+ MODEL_REPO_HOSODA = "akiyamasho/AnimeBackgroundGAN-Hosoda"
39
+ MODEL_FILE_HOSODA = "hosoda_mamoru.pth"
40
+
41
+ MODEL_REPO_MIYAZAKI = "akiyamasho/AnimeBackgroundGAN-Miyazaki"
42
+ MODEL_FILE_MIYAZAKI = "miyazaki_hayao.pth"
43
+
44
+ MODEL_REPO_KON = "akiyamasho/AnimeBackgroundGAN-Kon"
45
+ MODEL_FILE_KON = "kon_satoshi.pth"
46
+
47
+ # Model Initalisation
48
+ shinkai_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_SHINKAI, filename=MODEL_FILE_SHINKAI)
49
+ hosoda_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_HOSODA, filename=MODEL_FILE_HOSODA)
50
+ miyazaki_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_MIYAZAKI, filename=MODEL_FILE_MIYAZAKI)
51
+ kon_model_hfhub = hf_hub_download(repo_id=MODEL_REPO_KON, filename=MODEL_FILE_KON)
52
+
53
  shinkai_model = Transformer()
54
  hosoda_model = Transformer()
55
  miyazaki_model = Transformer()
56
  kon_model = Transformer()
57
 
58
+ enable_gpu = torch.cuda.is_available()
59
+ map_location = torch.device("cuda") if enable_gpu else "cpu"
60
 
61
  shinkai_model.load_state_dict(
62
+ torch.load(shinkai_model_hfhub, map_location=map_location)
63
  )
64
  hosoda_model.load_state_dict(
65
+ torch.load(hosoda_model_hfhub, map_location=map_location)
66
  )
67
  miyazaki_model.load_state_dict(
68
+ torch.load(miyazaki_model_hfhub, map_location=map_location)
69
  )
70
  kon_model.load_state_dict(
71
+ torch.load(kon_model_hfhub, map_location=map_location)
72
  )
73
 
74
  shinkai_model.eval()
76
  miyazaki_model.eval()
77
  kon_model.eval()
78
 
 
79
 
80
+ # Functions
81
 
82
  def get_model(style):
83
  if style == STYLE_SHINKAI:
134
  return transforms.ToPILImage()(output_image)
135
 
136
 
137
+ # Gradio setup
138
+
139
  title = "Anime Background GAN"
140
  description = "Gradio Demo for CartoonGAN by Chen Et. Al. Models are Shinkai Makoto, Hosoda Mamoru, Kon Satoshi, and Miyazaki Hayao."
141
+ article = "<p style='text-align: center'><a href='http://openaccess.thecvf.com/content_cvpr_2018/CameraReady/2205.pdf' target='_blank'>CartoonGAN Whitepaper from Chen et.al</a></p><p style='text-align: center'><a href='https://github.com/venture-anime/cartoongan-pytorch' target='_blank'>Github Repo</a></p><p style='text-align: center'><a href='https://github.com/Yijunmaverick/CartoonGAN-Test-Pytorch-Torch' target='_blank'>Original Implementation from Yijunmaverick</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=akiyamasho' alt='visitor badge'></center></p>"
142
 
143
  examples = [
144
  ["examples/garden_in.jpg", STYLE_SHINKAI],