# test/Embed_Secret.py — author: Tu Bui; commit 6142a25 ("first commit"); 9.49 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
streamlit app demo
how to run:
streamlit run app.py --server.port 8501
@author: Tu Bui @surrey.ac.uk
"""
import os, sys, torch
import argparse
from pathlib import Path
import numpy as np
import pickle
import pytorch_lightning as pl
from torchvision import transforms
import argparse
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
from PIL import Image
from tools.augment_imagenetc import RandomImagenetC
from io import BytesIO
from tools.helpers import welcome_message
from tools.ecc import BCH, RSC
import streamlit as st
from streamlit.source_util import (
page_icon_and_name,
calc_md5,
get_pages,
_on_pages_changed
)
# Model names offered in the UI selectbox; load_model, embed_secret and
# decode_secret dispatch on these names.
model_names = ["UNet"]
# Length of the secret bit vector. load_UNet checks the model config against
# it and app() checks the ECC's total length against it.
SECRET_LEN = 100
def delete_page(main_script_path_str, page_name):
    """Remove the first page whose 'page_name' equals ``page_name``.

    Looks the page up in the dict returned by ``get_pages`` for the given main
    script and deletes its entry if one matches (no-op otherwise). It then
    always calls ``_on_pages_changed.send()``.

    Args:
        main_script_path_str: path of the main Streamlit script, as passed to
            ``get_pages``.
        page_name: value to match against each entry's 'page_name' field.
    """
    current_pages = get_pages(main_script_path_str)
    # Locate the key first and delete afterwards, so the dict is never
    # mutated while it is being iterated.
    match = next(
        (key for key, value in current_pages.items()
         if value['page_name'] == page_name),
        None,
    )
    if match is not None:
        del current_pages[match]
    _on_pages_changed.send()
def add_page(main_script_path_str, page_name):
    """Register a script whose file name contains ``page_name`` as a page.

    Candidates are searched in ``<main script dir>/pages/*.py`` first and then
    ``<main script dir>/*.py``. The first match is used; an IndexError is
    raised if nothing matches. Its entry is stored in the dict returned by
    ``get_pages``, keyed by ``calc_md5`` of the resolved script path. Finally
    ``_on_pages_changed.send()`` is called.

    Args:
        main_script_path_str: path of the main Streamlit script.
        page_name: substring that the target script's file name must contain.
    """
    registry = get_pages(main_script_path_str)
    script_dir = Path(main_script_path_str).parent
    candidates = [*(script_dir / "pages").glob("*.py"), *script_dir.glob("*.py")]
    script_path = [p for p in candidates if page_name in p.name][0]
    resolved = str(script_path.resolve())
    icon, name = page_icon_and_name(script_path)
    page_hash = calc_md5(resolved)
    registry[page_hash] = {
        "page_script_hash": page_hash,
        "page_name": name,
        "icon": icon,
        "script_path": resolved,
    }
    _on_pages_changed.send()
def unormalize(x):
    """Map a batch of images from [-1, 1] floats to uint8 pixels.

    Args:
        x: (B, C, H, W) tensor with values nominally in [-1, 1].

    Returns:
        (B, H, W, C) uint8 numpy array on the CPU. Values are scaled to
        [0, 255], clamped, and truncated (not rounded) to integers.
    """
    pixels = ((x + 1) * 127.5).clamp(min=0, max=255)
    channels_last = pixels.permute(0, 2, 3, 1)
    return channels_last.cpu().numpy().astype(np.uint8)
def to_bytes(x, mime):
    """Encode an image array into in-memory file bytes.

    Args:
        x: array accepted by ``Image.fromarray``.
        mime: 'image/jpeg' selects JPEG encoding; any other value selects PNG.

    Returns:
        The encoded image as ``bytes``.
    """
    fmt = "JPEG" if mime == 'image/jpeg' else "PNG"
    with BytesIO() as buffer:
        Image.fromarray(x).save(buffer, format=fmt)
        return buffer.getvalue()
def _fetch_if_url(location, cache_dir='./weights'):
    """Return a local path for ``location``, downloading it first if it is a URL.

    Any ``location`` starting with 'http' is saved into ``cache_dir`` under
    the URL's last path component. The download is skipped when that file
    already exists. Other locations are returned unchanged.
    """
    if not location.startswith('http'):
        return location
    cache = Path(cache_dir)
    cache.mkdir(exist_ok=True)
    local_path = cache / location.split('/')[-1]
    if not local_path.exists():
        import wget  # only needed when something actually has to be downloaded
        print(f'Downloading {location}...')
        with st.spinner("Downloading model... this may take awhile!"):
            wget.download(location, str(local_path))
    return str(local_path)


def load_UNet(args):
    """Build the UNet watermarking model from a config file and a checkpoint.

    Args:
        args: namespace giving the config and checkpoint locations, either as
            ``config_file``/``weight_file`` or as ``config``/``weight`` (the
            names produced by ``parse_st_args``). Each location may be a local
            path or an http(s) URL. URLs are downloaded once into ./weights
            and reused afterwards.

    Returns:
        The model with the checkpoint loaded non-strictly. It is moved to CUDA
        when available (CPU otherwise) and put in eval mode.

    Raises:
        ValueError: if the config's ``params.secret_len`` differs from
            SECRET_LEN.
    """
    print('args: ', args)
    # parse_st_args() stores --config/--weight as ``config``/``weight``; other
    # callers may pass ``config_file``/``weight_file``. Accept both.
    config_file = getattr(args, 'config_file', None) or args.config
    weight_file = getattr(args, 'weight_file', None) or args.weight
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Fetch each file independently, so a cached checkpoint with a missing
    # config (or the reverse) is still repaired.
    weight_file = _fetch_if_url(weight_file)
    config_file = _fetch_if_url(config_file)

    config = OmegaConf.load(config_file).model
    secret_len = config.params.secret_len
    if secret_len != SECRET_LEN:
        raise ValueError(
            f'Model secret_len={secret_len} does not match SECRET_LEN={SECRET_LEN}')
    model = instantiate_from_config(config)

    # NOTE(security): torch.load unpickles the file, which can execute code.
    # Only point --weight at trusted checkpoints (downloaded URLs included).
    checkpoint = torch.load(weight_file, map_location=torch.device('cpu'))
    if 'global_step' in checkpoint:
        print(f'Global step: {checkpoint["global_step"]}, epoch: {checkpoint.get("epoch")}')
    # Lightning-style checkpoints nest the weights under 'state_dict'.
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    misses, ignores = model.load_state_dict(state_dict, strict=False)
    print(f'Missed keys: {misses}\nIgnore keys: {ignores}')
    model = model.to(device)
    model.eval()
    return model
def embed_secret(model_name, model, cover, tform, secret):
    """Embed ``secret`` into ``cover`` and return the stego image as uint8.

    The model runs at the resolution produced by ``tform`` (256x256 with the
    ``load_model`` transforms). Only the residual (clamped stego minus model
    input) is upsampled back to the cover's size and added to the original
    pixels, so the cover keeps its full resolution.

    Args:
        model_name: only 'UNet' is supported; others raise NotImplementedError.
        model: torch module called as ``model(im, secret)``; the first element
            of its output is treated as the stego image and clamped to [-1, 1].
        cover: RGB image, either a PIL image or an (H, W, 3) uint8 array.
        tform: callable mapping ``cover`` to a (3, h, w) tensor, expected to be
            normalized to [-1, 1] like the ``load_model`` transforms.
        secret: (1, len) float bit tensor; moved to the model's device.

    Returns:
        (H, W, 3) uint8 numpy array with the same size as ``cover``.
    """
    if model_name != 'UNet':
        raise NotImplementedError
    cover_arr = np.asarray(cover)
    h, w = cover_arr.shape[:2]
    # Use the model's own device instead of assuming CUDA; load_UNet falls
    # back to CPU when CUDA is unavailable.
    device = next(model.parameters()).device
    with torch.no_grad():
        im = tform(cover).unsqueeze(0).to(device)  # (1, 3, h', w')
        stego, _ = model(im, secret.to(device))
        residual = stego.clamp(-1, 1) - im  # (1, 3, h', w')
        residual = torch.nn.functional.interpolate(residual, (h, w), mode='bilinear')
        residual = residual.permute(0, 2, 3, 1).cpu().numpy()  # (1, H, W, 3)
    cover_norm = cover_arr / 127.5 - 1.  # cover pixels mapped to [-1, 1]
    stego_float = np.clip(residual[0] + cover_norm, -1, 1) * 127.5 + 127.5
    return stego_float.astype(np.uint8)
def identity(x):
    """Return ``x`` unchanged; usable as a no-op transform."""
    result = x
    return result
def decode_secret(model_name, model, im, tform):
    """Recover the embedded bits from an image.

    Args:
        model_name: 'RoSteALS' or 'UNet'; others raise NotImplementedError.
        model: torch module exposing a ``decoder`` callable.
        im: image accepted by ``tform`` (a PIL image in this app).
        tform: callable producing a (C, H, W) tensor from ``im``.

    Returns:
        Boolean numpy array that is True where the decoder output is > 0.
        Presumably shaped (1, SECRET_LEN); confirm against the decoder.
    """
    if model_name not in ('RoSteALS', 'UNet'):
        raise NotImplementedError
    # Run on the model's device instead of assuming CUDA; load_UNet falls
    # back to CPU when CUDA is unavailable.
    device = next(model.parameters()).device
    with torch.no_grad():
        batch = tform(im).unsqueeze(0).to(device)
        secret_pred = (model.decoder(batch) > 0).cpu().numpy()
    return secret_pred
@st.cache_resource
def load_model(model_name, _args):
    """Load a model and its embedding/detection transforms (cached by Streamlit).

    Args:
        model_name: only 'UNet' is supported; others raise NotImplementedError.
        _args: options forwarded to ``load_UNet``. The leading underscore
            tells ``st.cache_resource`` not to hash this argument.

    Returns:
        (model, tform_emb, tform_det). tform_emb resizes to 256x256 and
        tform_det to 224x224; both convert to a tensor normalized with
        mean = std = 0.5 per channel.
    """
    if model_name != 'UNet':
        raise NotImplementedError
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    tform_emb = transforms.Compose(
        [transforms.Resize((256, 256)), transforms.ToTensor(), normalize])
    tform_det = transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor(), normalize])
    model = load_UNet(_args)
    return model, tform_emb, tform_det
@st.cache_resource
def load_ecc(ecc_name):
    """Build the error-correcting code used for secrets (cached by Streamlit).

    Args:
        ecc_name: 'BCH' (constructed with payload_len=SECRET_LEN) or 'RSC'
            (constructed with 16 data bytes and 4 ECC bytes).

    Returns:
        The BCH or RSC instance.

    Raises:
        NotImplementedError: for any other ``ecc_name``. Without this, the
            name fell through to the return and raised UnboundLocalError.
    """
    if ecc_name == 'BCH':
        return BCH(payload_len=SECRET_LEN, verbose=True)
    if ecc_name == 'RSC':
        return RSC(data_bytes=16, ecc_bytes=4, verbose=True)
    raise NotImplementedError(f'Unsupported ECC: {ecc_name}')
class Resize(object):
    """Resize an image to a target (width, height) and return a numpy array.

    Uses LANCZOS when the input's shorter side is larger than the target's
    shorter side (shrinking), and BILINEAR otherwise.
    """

    def __init__(self, size=None) -> None:
        # Default target size as a PIL (width, height) tuple; a per-call
        # ``size`` overrides it.
        self.size = size

    def __call__(self, x, size=None):
        """Resize ``x`` (a PIL image, or an ndarray converted via Image.fromarray)."""
        img = Image.fromarray(x) if isinstance(x, np.ndarray) else x
        target = self.size if size is None else size
        shrinking = min(img.size) > min(target)
        resample = Image.LANCZOS if shrinking else Image.BILINEAR
        return np.array(img.resize(target, resample))
def parse_st_args():
    """Parse the app's command-line options.

    Usage: streamlit run app.py -- --weight PATH_OR_URL --config PATH_OR_URL

    Returns:
        argparse.Namespace with ``weight`` and ``config``. The same values are
        also stored as ``weight_file`` and ``config_file``, which are the
        attribute names load_UNet reads.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight', default='/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/checkpoints/epoch=000070-step=000219999.ckpt')
    parser.add_argument('--config', default='/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/configs/-project.yaml')
    args = parser.parse_args()
    # Aliases expected by load_UNet; without them app(parse_st_args())
    # failed with AttributeError on args.config_file.
    args.weight_file = args.weight
    args.config_file = args.config
    return args
def app(args):
    """Streamlit page: embed a text secret into an uploaded image and verify it.

    Flow:
      1. The user picks a model, uploads an image and types a secret of at
         most ``ecc.data_len`` characters.
      2. The secret is encoded with the BCH code from ``load_ecc`` and
         embedded with ``embed_secret``.
      3. The stego image is shown with a download button.
      4. A resized (256x256) and center-cropped (224x224) copy is decoded,
         and the recovered text and bit accuracy are reported.

    Args:
        args: parsed command-line options, forwarded to ``load_model``.
    """
    st.title('Watermarking Demo')
    # setup model
    model_name = st.selectbox("Choose the model", model_names)
    model, tform_emb, tform_det = load_model(model_name, args)
    display_width = 300
    # ecc
    ecc = load_ecc('BCH')
    assert ecc.get_total_len() == SECRET_LEN
    # setup st
    st.subheader("Input")
    image_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
    secret_text = None
    if image_file is not None:
        print('Image: ', image_file.name)
        # Lower-case so e.g. 'photo.JPG' still maps to image/jpeg below.
        ext = image_file.name.split('.')[-1].lower()
        im = Image.open(image_file).convert('RGB')
        st.image(im, width=display_width)
        secret_text = st.text_input(f'Input the secret (max {ecc.data_len} chars)', 'A secret')
        # This is user input: report the problem in the UI instead of
        # failing an assert with a traceback.
        if len(secret_text) > ecc.data_len:
            st.error(f'The secret is too long ({len(secret_text)} chars, max {ecc.data_len}).')
            st.stop()
    # embed
    st.subheader("Embed results")
    status = st.empty()
    # Preprocessing applied to the stego image before the verification decode.
    prep = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224)),
    ])
    if image_file is not None and secret_text is not None:
        secret = ecc.encode_text([secret_text])  # (1, len)
        # Follow the model's device; load_UNet falls back to CPU without CUDA.
        device = next(model.parameters()).device
        secret = torch.from_numpy(secret).float().to(device)
        stego = embed_secret(model_name, model, im, tform_emb, secret)
        st.image(stego, width=display_width)
        # download button
        mime = 'image/jpeg' if ext in ('jpg', 'jpeg') else f'image/{ext}'
        stego_bytes = to_bytes(stego, mime)
        st.download_button(label='Download image', data=stego_bytes,
                           file_name=f'stego.{ext}', mime=mime)
        # verify secret
        stego_processed = prep(Image.fromarray(stego))
        secret_pred = decode_secret(model_name, model, stego_processed, tform_det)
        bit_acc = (secret_pred == secret.cpu().numpy()).mean()
        recovered = ecc.decode_text(secret_pred)[0]
        # Use one markdown call: a second call on the same st.empty()
        # placeholder replaces the first message. HTML stays disabled because
        # the recovered text comes from user input.
        status.markdown(f'**Secret recovery check:** {recovered}\n\n'
                        f'**Bit accuracy:** {bit_acc}')
if __name__ == '__main__':
    # `streamlit run` executes this file as __main__. App options follow a
    # bare `--` on the command line (see parse_st_args).
    args = parse_st_args()
    app(args)