Tu Bui commited on
Commit
17b1745
1 Parent(s): e97a15c

fix cuda -> cpu

Browse files
Files changed (1) hide show
  1. Embed_Secret.py +3 -3
Embed_Secret.py CHANGED
@@ -130,7 +130,7 @@ def embed_secret(model_name, model, cover, tform, secret):
130
  if model_name == 'UNet':
131
  w, h = cover.size
132
  with torch.no_grad():
133
- im = tform(cover).unsqueeze(0).cuda() # 1, 3, 256, 256
134
  stego, _ = model(im, secret) # 1, 3, 256, 256
135
  res = (stego.clamp(-1,1) - im) # (1,3,256,256) residual
136
  res = torch.nn.functional.interpolate(res, (h,w), mode='bilinear')
@@ -147,7 +147,7 @@ def identity(x):
147
  def decode_secret(model_name, model, im, tform):
148
  if model_name in ['RoSteALS', 'UNet']:
149
  with torch.no_grad():
150
- im = tform(im).unsqueeze(0).cuda() # 1, 3, 256, 256
151
  secret_pred = (model.decoder(im) > 0).cpu().numpy() # 1, 100
152
  else:
153
  raise NotImplementedError
@@ -241,7 +241,7 @@ def app(args):
241
  ])
242
  if image_file is not None and secret_text is not None:
243
  secret = ecc.encode_text([secret_text]) # (1, len)
244
- secret = torch.from_numpy(secret).float().cuda()
245
  # im = tform(im).unsqueeze(0).cuda() # (1,3,H,W)
246
  stego = embed_secret(model_name, model, im, tform_emb, secret)
247
  st.image(stego, width=display_width)
 
130
  if model_name == 'UNet':
131
  w, h = cover.size
132
  with torch.no_grad():
133
+ im = tform(cover).unsqueeze(0).to(model.device) # 1, 3, 256, 256
134
  stego, _ = model(im, secret) # 1, 3, 256, 256
135
  res = (stego.clamp(-1,1) - im) # (1,3,256,256) residual
136
  res = torch.nn.functional.interpolate(res, (h,w), mode='bilinear')
 
147
  def decode_secret(model_name, model, im, tform):
148
  if model_name in ['RoSteALS', 'UNet']:
149
  with torch.no_grad():
150
+ im = tform(im).unsqueeze(0).to(model.device) # 1, 3, 256, 256
151
  secret_pred = (model.decoder(im) > 0).cpu().numpy() # 1, 100
152
  else:
153
  raise NotImplementedError
 
241
  ])
242
  if image_file is not None and secret_text is not None:
243
  secret = ecc.encode_text([secret_text]) # (1, len)
244
+ secret = torch.from_numpy(secret).float().to(model.device)
245
  # im = tform(im).unsqueeze(0).cuda() # (1,3,H,W)
246
  stego = embed_secret(model_name, model, im, tform_emb, secret)
247
  st.image(stego, width=display_width)