Tu Bui commited on
Commit
04acf84
1 Parent(s): f94afcc

fix input arg name

Browse files
Embed_Secret.py CHANGED
@@ -91,8 +91,8 @@ def load_UNet(args):
91
  # config_file = '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/configs/-project.yaml'
92
  # weight_file = '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/checkpoints/epoch=000070-step=000219999.ckpt'
93
 
94
- config_file = args.config_file
95
- weight_file = args.weight_file
96
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
97
  if weight_file.startswith('http'): # download from url
98
  weight_dir = Path('./weights')
 
91
  # config_file = '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/configs/-project.yaml'
92
  # weight_file = '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/checkpoints/epoch=000070-step=000219999.ckpt'
93
 
94
+ config_file = args.config
95
+ weight_file = args.weight
96
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
97
  if weight_file.startswith('http'): # download from url
98
  weight_dir = Path('./weights')
cldm/transformations2.py CHANGED
@@ -367,7 +367,7 @@ class TransformNet(nn.Module):
367
  def apply_transform_on_pil_image(self, x, tform_name):
368
  # x: PIL image
369
  # return: PIL image
370
- assert tform_name in self.optional_names + ['Random Crop', 'Random Flip']
371
  # if tform_name == 'Random Crop': # the only transform dependent on image size
372
  # # crop equivalent to 224/256
373
  # w, h = x.size
@@ -392,10 +392,11 @@ class TransformNet(nn.Module):
392
  x = x.resize((256, 256), Image.BILINEAR)
393
  x = np.array(x).astype(np.float32) / 255. # [0, 255] -> [0, 1]
394
  x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
395
- if tform_name == 'Random Flip':
396
- x = self.fixed_transforms[0](x)
397
- elif tform_name == 'Random Crop':
398
- x = self.fixed_transforms[1](x)
 
399
  else:
400
  tform_id = self.optional_names.index(tform_name)
401
  tform = self.optional_transforms[tform_id]
 
367
  def apply_transform_on_pil_image(self, x, tform_name):
368
  # x: PIL image
369
  # return: PIL image
370
+ assert tform_name in self.optional_names + ['Fixed Augment']
371
  # if tform_name == 'Random Crop': # the only transform dependent on image size
372
  # # crop equivalent to 224/256
373
  # w, h = x.size
 
392
  x = x.resize((256, 256), Image.BILINEAR)
393
  x = np.array(x).astype(np.float32) / 255. # [0, 255] -> [0, 1]
394
  x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
395
+ if tform_name == 'Fixed Augment':
396
+ for tform in self.fixed_transforms:
397
+ x = tform(x)
398
+ if isinstance(x, tuple):
399
+ x = x[0]
400
  else:
401
  tform_id = self.optional_names.index(tform_name)
402
  tform = self.optional_transforms[tform_id]
pages/Extract_Secret.py CHANGED
@@ -27,17 +27,17 @@ from io import BytesIO
27
  from tools.helpers import welcome_message
28
  from tools.ecc import BCH, RSC
29
  import streamlit as st
30
- from Embed_Secret import load_ecc, load_model, decode_secret, to_bytes, model_names, SECRET_LEN
31
 
32
 
33
  # model_names = ['RoSteALS', 'UNet']
34
  # SECRET_LEN = 100
35
 
36
- def app():
37
  st.title('Watermarking Demo')
38
  # setup model
39
  model_name = st.selectbox("Choose the model", model_names)
40
- model, tform_emb, tform_det = load_model(model_name)
41
  display_width = 300
42
  ecc = load_ecc('BCH')
43
  noise = TransformNet(p=1.0, crop_mode='resized_crop')
@@ -55,11 +55,11 @@ def app():
55
 
56
  # add crop
57
  st.subheader("Corruptions")
58
- crop_button = st.button('Regenerate Crop', key='crop')
59
  if image_file is not None:
60
- im_crop = noise.apply_transform_on_pil_image(im, 'Random Crop')
61
  if crop_button:
62
- im_crop = noise.apply_transform_on_pil_image(im, 'Random Crop')
63
  # st.image(im_crop, width=display_width)
64
 
65
  # add noise source 1
@@ -106,5 +106,6 @@ def app():
106
  bit_acc_status.markdown(f'Bit Accuracy: **{bit_acc*100:.2f}%**<br />Word Accuracy: **{word_acc}**', unsafe_allow_html=True)
107
 
108
  if __name__ == '__main__':
109
- app()
 
110
 
 
27
  from tools.helpers import welcome_message
28
  from tools.ecc import BCH, RSC
29
  import streamlit as st
30
+ from Embed_Secret import parse_st_args, load_ecc, load_model, decode_secret, to_bytes, model_names, SECRET_LEN
31
 
32
 
33
  # model_names = ['RoSteALS', 'UNet']
34
  # SECRET_LEN = 100
35
 
36
+ def app(args):
37
  st.title('Watermarking Demo')
38
  # setup model
39
  model_name = st.selectbox("Choose the model", model_names)
40
+ model, tform_emb, tform_det = load_model(model_name, args)
41
  display_width = 300
42
  ecc = load_ecc('BCH')
43
  noise = TransformNet(p=1.0, crop_mode='resized_crop')
 
55
 
56
  # add crop
57
  st.subheader("Corruptions")
58
+ crop_button = st.button('Regenerate Crop/Flip/Resize', key='crop')
59
  if image_file is not None:
60
+ im_crop = noise.apply_transform_on_pil_image(im, 'Fixed Augment')
61
  if crop_button:
62
+ im_crop = noise.apply_transform_on_pil_image(im, 'Fixed Augment')
63
  # st.image(im_crop, width=display_width)
64
 
65
  # add noise source 1
 
106
  bit_acc_status.markdown(f'Bit Accuracy: **{bit_acc*100:.2f}%**<br />Word Accuracy: **{word_acc}**', unsafe_allow_html=True)
107
 
108
  if __name__ == '__main__':
109
+ args = parse_st_args()
110
+ app(args)
111