Tu Bui
commited on
Commit
•
17b1745
1
Parent(s):
e97a15c
fix cuda -> cpu
Browse files- Embed_Secret.py +3 -3
Embed_Secret.py
CHANGED
@@ -130,7 +130,7 @@ def embed_secret(model_name, model, cover, tform, secret):
|
|
130 |
if model_name == 'UNet':
|
131 |
w, h = cover.size
|
132 |
with torch.no_grad():
|
133 |
-
im = tform(cover).unsqueeze(0).
|
134 |
stego, _ = model(im, secret) # 1, 3, 256, 256
|
135 |
res = (stego.clamp(-1,1) - im) # (1,3,256,256) residual
|
136 |
res = torch.nn.functional.interpolate(res, (h,w), mode='bilinear')
|
@@ -147,7 +147,7 @@ def identity(x):
|
|
147 |
def decode_secret(model_name, model, im, tform):
|
148 |
if model_name in ['RoSteALS', 'UNet']:
|
149 |
with torch.no_grad():
|
150 |
-
im = tform(im).unsqueeze(0).
|
151 |
secret_pred = (model.decoder(im) > 0).cpu().numpy() # 1, 100
|
152 |
else:
|
153 |
raise NotImplementedError
|
@@ -241,7 +241,7 @@ def app(args):
|
|
241 |
])
|
242 |
if image_file is not None and secret_text is not None:
|
243 |
secret = ecc.encode_text([secret_text]) # (1, len)
|
244 |
-
secret = torch.from_numpy(secret).float().
|
245 |
# im = tform(im).unsqueeze(0).cuda() # (1,3,H,W)
|
246 |
stego = embed_secret(model_name, model, im, tform_emb, secret)
|
247 |
st.image(stego, width=display_width)
|
|
|
130 |
if model_name == 'UNet':
|
131 |
w, h = cover.size
|
132 |
with torch.no_grad():
|
133 |
+
im = tform(cover).unsqueeze(0).to(model.device) # 1, 3, 256, 256
|
134 |
stego, _ = model(im, secret) # 1, 3, 256, 256
|
135 |
res = (stego.clamp(-1,1) - im) # (1,3,256,256) residual
|
136 |
res = torch.nn.functional.interpolate(res, (h,w), mode='bilinear')
|
|
|
147 |
def decode_secret(model_name, model, im, tform):
|
148 |
if model_name in ['RoSteALS', 'UNet']:
|
149 |
with torch.no_grad():
|
150 |
+
im = tform(im).unsqueeze(0).to(model.device) # 1, 3, 256, 256
|
151 |
secret_pred = (model.decoder(im) > 0).cpu().numpy() # 1, 100
|
152 |
else:
|
153 |
raise NotImplementedError
|
|
|
241 |
])
|
242 |
if image_file is not None and secret_text is not None:
|
243 |
secret = ecc.encode_text([secret_text]) # (1, len)
|
244 |
+
secret = torch.from_numpy(secret).float().to(model.device)
|
245 |
# im = tform(im).unsqueeze(0).cuda() # (1,3,H,W)
|
246 |
stego = embed_secret(model_name, model, im, tform_emb, secret)
|
247 |
st.image(stego, width=display_width)
|