Tu Bui
commited on
Commit
•
04acf84
1
Parent(s):
f94afcc
fix input arg name
Browse files- Embed_Secret.py +2 -2
- cldm/transformations2.py +6 -5
- pages/Extract_Secret.py +8 -7
Embed_Secret.py
CHANGED
@@ -91,8 +91,8 @@ def load_UNet(args):
|
|
91 |
# config_file = '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/configs/-project.yaml'
|
92 |
# weight_file = '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/checkpoints/epoch=000070-step=000219999.ckpt'
|
93 |
|
94 |
-
config_file = args.
|
95 |
-
weight_file = args.
|
96 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
97 |
if weight_file.startswith('http'): # download from url
|
98 |
weight_dir = Path('./weights')
|
|
|
91 |
# config_file = '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/configs/-project.yaml'
|
92 |
# weight_file = '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/checkpoints/epoch=000070-step=000219999.ckpt'
|
93 |
|
94 |
+
config_file = args.config
|
95 |
+
weight_file = args.weight
|
96 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
97 |
if weight_file.startswith('http'): # download from url
|
98 |
weight_dir = Path('./weights')
|
cldm/transformations2.py
CHANGED
@@ -367,7 +367,7 @@ class TransformNet(nn.Module):
|
|
367 |
def apply_transform_on_pil_image(self, x, tform_name):
|
368 |
# x: PIL image
|
369 |
# return: PIL image
|
370 |
-
assert tform_name in self.optional_names + ['
|
371 |
# if tform_name == 'Random Crop': # the only transform dependent on image size
|
372 |
# # crop equivalent to 224/256
|
373 |
# w, h = x.size
|
@@ -392,10 +392,11 @@ class TransformNet(nn.Module):
|
|
392 |
x = x.resize((256, 256), Image.BILINEAR)
|
393 |
x = np.array(x).astype(np.float32) / 255. # [0, 255] -> [0, 1]
|
394 |
x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
|
395 |
-
if tform_name == '
|
396 |
-
|
397 |
-
|
398 |
-
|
|
|
399 |
else:
|
400 |
tform_id = self.optional_names.index(tform_name)
|
401 |
tform = self.optional_transforms[tform_id]
|
|
|
367 |
def apply_transform_on_pil_image(self, x, tform_name):
|
368 |
# x: PIL image
|
369 |
# return: PIL image
|
370 |
+
assert tform_name in self.optional_names + ['Fixed Augment']
|
371 |
# if tform_name == 'Random Crop': # the only transform dependent on image size
|
372 |
# # crop equivalent to 224/256
|
373 |
# w, h = x.size
|
|
|
392 |
x = x.resize((256, 256), Image.BILINEAR)
|
393 |
x = np.array(x).astype(np.float32) / 255. # [0, 255] -> [0, 1]
|
394 |
x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
|
395 |
+
if tform_name == 'Fixed Augment':
|
396 |
+
for tform in self.fixed_transforms:
|
397 |
+
x = tform(x)
|
398 |
+
if isinstance(x, tuple):
|
399 |
+
x = x[0]
|
400 |
else:
|
401 |
tform_id = self.optional_names.index(tform_name)
|
402 |
tform = self.optional_transforms[tform_id]
|
pages/Extract_Secret.py
CHANGED
@@ -27,17 +27,17 @@ from io import BytesIO
|
|
27 |
from tools.helpers import welcome_message
|
28 |
from tools.ecc import BCH, RSC
|
29 |
import streamlit as st
|
30 |
-
from Embed_Secret import load_ecc, load_model, decode_secret, to_bytes, model_names, SECRET_LEN
|
31 |
|
32 |
|
33 |
# model_names = ['RoSteALS', 'UNet']
|
34 |
# SECRET_LEN = 100
|
35 |
|
36 |
-
def app():
|
37 |
st.title('Watermarking Demo')
|
38 |
# setup model
|
39 |
model_name = st.selectbox("Choose the model", model_names)
|
40 |
-
model, tform_emb, tform_det = load_model(model_name)
|
41 |
display_width = 300
|
42 |
ecc = load_ecc('BCH')
|
43 |
noise = TransformNet(p=1.0, crop_mode='resized_crop')
|
@@ -55,11 +55,11 @@ def app():
|
|
55 |
|
56 |
# add crop
|
57 |
st.subheader("Corruptions")
|
58 |
-
crop_button = st.button('Regenerate Crop', key='crop')
|
59 |
if image_file is not None:
|
60 |
-
im_crop = noise.apply_transform_on_pil_image(im, '
|
61 |
if crop_button:
|
62 |
-
im_crop = noise.apply_transform_on_pil_image(im, '
|
63 |
# st.image(im_crop, width=display_width)
|
64 |
|
65 |
# add noise source 1
|
@@ -106,5 +106,6 @@ def app():
|
|
106 |
bit_acc_status.markdown(f'Bit Accuracy: **{bit_acc*100:.2f}%**<br />Word Accuracy: **{word_acc}**', unsafe_allow_html=True)
|
107 |
|
108 |
if __name__ == '__main__':
|
109 |
-
|
|
|
110 |
|
|
|
27 |
from tools.helpers import welcome_message
|
28 |
from tools.ecc import BCH, RSC
|
29 |
import streamlit as st
|
30 |
+
from Embed_Secret import parse_st_args, load_ecc, load_model, decode_secret, to_bytes, model_names, SECRET_LEN
|
31 |
|
32 |
|
33 |
# model_names = ['RoSteALS', 'UNet']
|
34 |
# SECRET_LEN = 100
|
35 |
|
36 |
+
def app(args):
|
37 |
st.title('Watermarking Demo')
|
38 |
# setup model
|
39 |
model_name = st.selectbox("Choose the model", model_names)
|
40 |
+
model, tform_emb, tform_det = load_model(model_name, args)
|
41 |
display_width = 300
|
42 |
ecc = load_ecc('BCH')
|
43 |
noise = TransformNet(p=1.0, crop_mode='resized_crop')
|
|
|
55 |
|
56 |
# add crop
|
57 |
st.subheader("Corruptions")
|
58 |
+
crop_button = st.button('Regenerate Crop/Flip/Resize', key='crop')
|
59 |
if image_file is not None:
|
60 |
+
im_crop = noise.apply_transform_on_pil_image(im, 'Fixed Augment')
|
61 |
if crop_button:
|
62 |
+
im_crop = noise.apply_transform_on_pil_image(im, 'Fixed Augment')
|
63 |
# st.image(im_crop, width=display_width)
|
64 |
|
65 |
# add noise source 1
|
|
|
106 |
bit_acc_status.markdown(f'Bit Accuracy: **{bit_acc*100:.2f}%**<br />Word Accuracy: **{word_acc}**', unsafe_allow_html=True)
|
107 |
|
108 |
if __name__ == '__main__':
|
109 |
+
args = parse_st_args()
|
110 |
+
app(args)
|
111 |
|