秋山翔 commited on
Commit
8941262
1 Parent(s): 449b632

FIX: limit image size to avoid exceeding CPU limit

Browse files
Files changed (1) hide show
  1. app.py +44 -30
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import torch
3
  import gradio as gr
4
  import numpy as np
@@ -12,7 +13,7 @@ import logging
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
- LOAD_SIZE = 1280
16
  MODEL_PATH = "models"
17
  COLOUR_MODEL = "RGB"
18
 
@@ -66,33 +67,40 @@ def get_model(style):
66
  return shinkai_model
67
 
68
 
 
 
 
 
 
 
 
 
69
  def inference(img, style):
70
- try:
71
- # load image
72
- input_image = img.convert(COLOUR_MODEL)
73
- input_image = np.asarray(input_image)
74
- # RGB -> BGR
75
- input_image = input_image[:, :, [2, 1, 0]]
76
- input_image = transforms.ToTensor()(input_image).unsqueeze(0)
77
- # preprocess, (-1, 1)
78
- input_image = -1 + 2 * input_image
79
-
80
- if disable_gpu:
81
- input_image = Variable(input_image).float()
82
- else:
83
- input_image = Variable(input_image).cuda()
84
-
85
- # forward
86
- model = get_model(style)
87
- output_image = model(input_image)
88
- output_image = output_image[0]
89
- # BGR -> RGB
90
- output_image = output_image[[2, 1, 0], :, :]
91
- output_image = output_image.data.cpu().float() * 0.5 + 0.5
92
-
93
- return transforms.ToPILImage()(output_image)
94
- except:
95
- logger.error(f"Error while processing image {img}")
96
 
97
 
98
  title = "Anime Background GAN"
@@ -108,7 +116,10 @@ examples = [
108
  gr.Interface(
109
  fn=inference,
110
  inputs=[
111
- gr.inputs.Image(type="pil", label="Input Photo"),
 
 
 
112
  gr.inputs.Dropdown(
113
  STYLE_CHOICE_LIST,
114
  type="value",
@@ -116,11 +127,14 @@ gr.Interface(
116
  label="Style",
117
  ),
118
  ],
119
- outputs=gr.outputs.Image(type="pil"),
 
 
 
120
  title=title,
121
  description=description,
122
  article=article,
123
  examples=examples,
124
  allow_flagging="never",
125
  allow_screenshot=False,
126
- ).launch(enable_queue=True, share=True)
 
1
  import os
2
+ import sys
3
  import torch
4
  import gradio as gr
5
  import numpy as np
 
13
 
14
  logger = logging.getLogger(__name__)
15
 
16
+ MAX_DIMENSION = 1280
17
  MODEL_PATH = "models"
18
  COLOUR_MODEL = "RGB"
19
 
 
67
  return shinkai_model
68
 
69
 
70
+ def validate_image_size(img):
71
+ print(f"{img.height} x {img.width}")
72
+ if img.height > MAX_DIMENSION or img.width > MAX_DIMENSION:
73
+ raise RuntimeError(
74
+ "Image size is too large. Please use an image less than {MAX_DIMENSION}px on both width and height"
75
+ )
76
+
77
+
78
  def inference(img, style):
79
+ validate_image_size(img)
80
+
81
+ # load image
82
+ input_image = img.convert(COLOUR_MODEL)
83
+ input_image = np.asarray(input_image)
84
+ # RGB -> BGR
85
+ input_image = input_image[:, :, [2, 1, 0]]
86
+ input_image = transforms.ToTensor()(input_image).unsqueeze(0)
87
+ # preprocess, (-1, 1)
88
+ input_image = -1 + 2 * input_image
89
+
90
+ if disable_gpu:
91
+ input_image = Variable(input_image).float()
92
+ else:
93
+ input_image = Variable(input_image).cuda()
94
+
95
+ # forward
96
+ model = get_model(style)
97
+ output_image = model(input_image)
98
+ output_image = output_image[0]
99
+ # BGR -> RGB
100
+ output_image = output_image[[2, 1, 0], :, :]
101
+ output_image = output_image.data.cpu().float() * 0.5 + 0.5
102
+
103
+ return transforms.ToPILImage()(output_image)
 
104
 
105
 
106
  title = "Anime Background GAN"
 
116
  gr.Interface(
117
  fn=inference,
118
  inputs=[
119
+ gr.inputs.Image(
120
+ type="pil",
121
+ label="Input Photo (less than 1280px on both width and height)",
122
+ ),
123
  gr.inputs.Dropdown(
124
  STYLE_CHOICE_LIST,
125
  type="value",
 
127
  label="Style",
128
  ),
129
  ],
130
+ outputs=gr.outputs.Image(
131
+ type="pil",
132
+ label="Make sure to resize to less than 1280px on both width and height if an error occurrs!",
133
+ ),
134
  title=title,
135
  description=description,
136
  article=article,
137
  examples=examples,
138
  allow_flagging="never",
139
  allow_screenshot=False,
140
+ ).launch(enable_queue=True)