menghanxia committed
Commit 40d12a9
1 Parent(s): aa8edf3

Fixed an issue where checkpoint loading required a GPU
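The root cause: torch.load restores each tensor to the device it was saved from, so a checkpoint saved on GPU fails to deserialize on a CPU-only host (such as a CPU-tier Hugging Face Space) unless its storages are remapped. A minimal sketch of the failure and the fix, using the same checkpoint path as this repo:

```python
import torch

# Without map_location, a GPU-saved checkpoint raises on a CPU-only machine:
#   RuntimeError: Attempting to deserialize an object on a CUDA device but
#   torch.cuda.is_available() is False. If you are running on a CPU-only
#   machine, please use torch.load with map_location=torch.device('cpu') ...
# checkpoint = torch.load("checkpoints/model_best.pth.tar")

# Remapping every storage to CPU at load time avoids the error:
checkpoint = torch.load("checkpoints/model_best.pth.tar",
                        map_location=torch.device('cpu'))
```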

Files changed (3)
  1. __pycache__/inference.cpython-39.pyc +0 -0
  2. app.py +4 -4
  3. inference.py +10 -3
__pycache__/inference.cpython-39.pyc ADDED
Binary file (3.3 kB).
 
app.py CHANGED
@@ -19,7 +19,7 @@ if RUN_MODE != "local":
  ## step 1: set up model
  device = "cpu"
  checkpt_path = "checkpoints/model_best.pth.tar"
- invhalfer = Inferencer(checkpoint_path=checkpt_path, model=ResHalf(train=False), use_cuda=False)
+ invhalfer = Inferencer(checkpoint_path=checkpt_path, model=ResHalf(train=False), use_cuda=False, multi_gpu=False)


  def prepare_data(input_img, decoding_only=False):
@@ -41,7 +41,7 @@ def run_invhalf(invhalfer, input_img, decoding_only, device="cuda"):
          print('>>>:halftoning mode')
          resHalftone, resColor = invhalfer(input_img, decoding_only=decoding_only)
          output = util.tensor2img(resHalftone / 2. + 0.5) * 255.
-     return (output+0.5).astype(np.uint8)
+     return np.clip(output, 0, 255).astype(np.uint8)


  def click_run(input_img, decoding_only):
@@ -66,7 +66,7 @@ with demo:

      Button_run.click(fn=click_run, inputs=[Image_input, Radio_mode], outputs=Image_output)

-     if RUN_MODE == "local":
+     if RUN_MODE != "local":
          gr.Examples(examples=[
              ['girl.png', "Halftoning (Photo2Halftone)"],
              ['wave.png', "Halftoning (Photo2Halftone)"],
@@ -74,7 +74,7 @@ with demo:
              ],
              inputs=[Image_input,Radio_mode], outputs=[Image_output], label="Examples")

- if RUN_MODE != "local":
+ if RUN_MODE == "local":
      demo.launch(server_name='9.134.253.83',server_port=7788)
  else:
      demo.launch()
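Two notes on the app.py changes above. Passing multi_gpu=False matches the new Inferencer code path that strips the DataParallel key prefix (see inference.py below). The np.clip change also hardens the uint8 cast: casting a float array straight to np.uint8 wraps around (strictly, is undefined) for values outside [0, 255], so slight overshoot in the network output could surface as speckle artifacts. A small sketch of the difference:

```python
import numpy as np

# Outputs scaled to [0, 255] can overshoot slightly in float arithmetic.
output = np.array([-3.0, 0.0, 127.9, 258.0], dtype=np.float32)

# Old code: (output + 0.5).astype(np.uint8) relies on C-style float->uint8
# conversion, which wraps out-of-range values (e.g. 258.5 -> 2) and is
# formally undefined behavior.

# New code: clamp into the valid range first, then cast.
safe = np.clip(output, 0, 255).astype(np.uint8)   # -> [0, 0, 127, 255]
```

Note that the new version also drops the +0.5 rounding, so in-range values are truncated rather than rounded; that half-intensity difference is invisible next to the wrap-around it prevents.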
inference.py CHANGED
@@ -13,18 +13,25 @@ from model.loss import l1_loss
  from utils import util
  from utils.dct import DCT_Lowfrequency
  from utils.filters_tensor import bgr2gray
-
+ from collections import OrderedDict

  class Inferencer:
      def __init__(self, checkpoint_path, model, use_cuda=True, multi_gpu=True):
-         self.checkpoint = torch.load(checkpoint_path)
+         self.checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
          self.use_cuda = use_cuda
          self.model = model.eval()
          if multi_gpu:
              self.model = torch.nn.DataParallel(self.model)
+             state_dict = self.checkpoint['state_dict']
+         else:
+             ## remove keyword "module" in the state_dict
+             state_dict = OrderedDict()
+             for k, v in self.checkpoint['state_dict'].items():
+                 name = k[7:]
+                 state_dict[name] = v
          if self.use_cuda:
              self.model = self.model.cuda()
-         self.model.load_state_dict(self.checkpoint['state_dict'])
+         self.model.load_state_dict(state_dict)

      def __call__(self, input_img, decoding_only=False):
          with torch.no_grad():
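The new else branch exists because a state_dict saved from a torch.nn.DataParallel wrapper carries a "module." prefix on every key (7 characters, hence k[7:]), which a bare model cannot load. A quick illustration of where the prefix comes from:

```python
import torch

net = torch.nn.Linear(4, 2)
print(list(net.state_dict()))      # ['weight', 'bias']

# DataParallel stores the wrapped network under self.module, so every
# state_dict key gains a 'module.' prefix.
wrapped = torch.nn.DataParallel(net)
print(list(wrapped.state_dict()))  # ['module.weight', 'module.bias']

# len('module.') == 7, which is exactly what the k[7:] slice in
# Inferencer.__init__ strips before load_state_dict on the bare model.
```

With multi_gpu=True the wrapper is kept, so the prefixed keys load as-is; with multi_gpu=False the prefix must be stripped first, which is why app.py now constructs the Inferencer with both use_cuda=False and multi_gpu=False.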