#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Streamlit demo app: embed a short text secret (watermark) into an image.

How to run:
    streamlit run app.py --server.port 8501
    # app-specific arguments go after a bare `--`, e.g.
    streamlit run app.py -- --weight <ckpt path or URL> --config <yaml path or URL>

@author: Tu Bui @surrey.ac.uk
"""
import os
import sys
import argparse
import pickle
from io import BytesIO
from pathlib import Path

import numpy as np
import torch
import pytorch_lightning as pl
from torchvision import transforms
from omegaconf import OmegaConf
from PIL import Image
import streamlit as st
from streamlit.source_util import (
    page_icon_and_name, calc_md5, get_pages, _on_pages_changed
)

from ldm.util import instantiate_from_config
from tools.augment_imagenetc import RandomImagenetC
from tools.helpers import welcome_message
from tools.ecc import BCH, RSC

# Models selectable in the UI; each must be handled by load_model/embed_secret/decode_secret.
model_names = ['UNet']


def delete_page(main_script_path_str, page_name):
    """Remove the page titled ``page_name`` from the multipage registry.

    Relies on private Streamlit helpers (``streamlit.source_util``); these may
    change between Streamlit versions.
    """
    current_pages = get_pages(main_script_path_str)
    for key, value in current_pages.items():
        print(value['page_name'])
        if value['page_name'] == page_name:
            # Safe to mutate while iterating only because we break immediately.
            del current_pages[key]
            break
    _on_pages_changed.send()


def add_page(main_script_path_str, page_name):
    """Register the first ``*.py`` script whose file name contains ``page_name``.

    Searches ``<main script dir>/pages`` first, then the main script's own
    directory, and notifies Streamlit that the page set changed.

    Raises:
        FileNotFoundError: if no script name contains ``page_name``.
    """
    pages = get_pages(main_script_path_str)
    main_script_path = Path(main_script_path_str)
    pages_dir = main_script_path.parent / "pages"
    candidates = list(pages_dir.glob("*.py")) + list(main_script_path.parent.glob("*.py"))
    matches = [f for f in candidates if page_name in f.name]
    if not matches:
        raise FileNotFoundError(
            f'No script matching {page_name!r} in {pages_dir} or {main_script_path.parent}')
    script_path = matches[0]
    script_path_str = str(script_path.resolve())
    pi, pn = page_icon_and_name(script_path)
    psh = calc_md5(script_path_str)
    pages[psh] = {
        "page_script_hash": psh,
        "page_name": pn,
        "icon": pi,
        "script_path": script_path_str,
    }
    _on_pages_changed.send()


def unormalize(x):
    """Map a (B, C, H, W) tensor in [-1, 1] to a (B, H, W, C) uint8 ndarray in [0, 255]."""
    x = torch.clamp((x + 1) * 127.5, 0, 255).permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
    return x


def to_bytes(x, mime):
    """Encode an image array as JPEG when ``mime == 'image/jpeg'``, otherwise PNG; return the bytes."""
    x = Image.fromarray(x)
    buf = BytesIO()
    f = "JPEG" if mime == 'image/jpeg' else "PNG"
    x.save(buf, format=f)
    byte_im = buf.getvalue()
    return byte_im


def load_UNet(args):
    """Build the UNet watermarking model from ``args.config`` and ``args.weight``.

    When ``args.weight`` starts with ``http``, the checkpoint and the config
    (assumed to be a URL as well) are downloaded into ``./weights`` and
    reused on later runs.

    Returns:
        (model, secret_len): the model on CUDA if available (else CPU) in eval
        mode, and ``secret_len`` read from the config's ``model.params``.
    """
    print('args: ', args)
    config_file = args.config
    weight_file = args.weight
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if weight_file.startswith('http'):  # download from url
        weight_dir = Path('./weights')
        weight_dir.mkdir(exist_ok=True)
        weight_path = weight_dir / weight_file.split('/')[-1]
        config_path = weight_dir / config_file.split('/')[-1]
        # Check each file on its own so a cached checkpoint whose config is
        # missing still gets the config fetched (and vice versa).
        if not (weight_path.exists() and config_path.exists()):
            import wget
            with st.spinner("Downloading model... this may take awhile!"):
                if not weight_path.exists():
                    print(f'Downloading {weight_file}...')
                    wget.download(weight_file, str(weight_path))
                if not config_path.exists():
                    wget.download(config_file, str(config_path))
        weight_file = str(weight_path)
        config_file = str(config_path)

    config = OmegaConf.load(config_file).model
    secret_len = config.params.secret_len
    print(f'Secret length: {secret_len}')
    model = instantiate_from_config(config)

    # NOTE(security): torch.load unpickles the file and can execute code;
    # only point --weight at trusted checkpoints (including remote URLs).
    state_dict = torch.load(weight_file, map_location=torch.device('cpu'))
    if 'global_step' in state_dict:
        print(f'Global step: {state_dict["global_step"]}, epoch: {state_dict.get("epoch")}')
    if 'state_dict' in state_dict:
        # Full training checkpoint (looks like PyTorch Lightning's layout): unwrap the weights.
        state_dict = state_dict['state_dict']
    misses, ignores = model.load_state_dict(state_dict, strict=False)
    print(f'Missed keys: {misses}\nIgnore keys: {ignores}')
    model = model.to(device)
    model.eval()
    return model, secret_len


def embed_secret(model_name, model, cover, tform, secret):
    """Embed ``secret`` into the PIL RGB image ``cover`` at its original resolution.

    The model runs at the resolution produced by ``tform``; only the residual
    it adds is resized back to the cover size and added to the original
    pixels, so full-resolution cover detail is preserved.

    Returns:
        (H, W, 3) uint8 ndarray where ``(W, H) == cover.size``.

    Raises:
        NotImplementedError: for any model other than 'UNet'.
    """
    if model_name == 'UNet':
        w, h = cover.size  # PIL reports (width, height)
        with torch.no_grad():
            im = tform(cover).unsqueeze(0).to(model.device)  # (1, 3, 256, 256) with tform_emb
            stego, _ = model(im, secret)
            res = stego.clamp(-1, 1) - im  # residual in the [-1, 1] normalized space
            res = torch.nn.functional.interpolate(res, (h, w), mode='bilinear')
            res = res.permute(0, 2, 3, 1).cpu().numpy()  # (1, H, W, 3)
        # Add the residual to the cover in [-1, 1] space, then map back to [0, 255].
        stego_uint8 = np.clip(res[0] + np.array(cover) / 127.5 - 1., -1, 1) * 127.5 + 127.5
        stego_uint8 = stego_uint8.astype(np.uint8)  # (H, W, 3)
    else:
        raise NotImplementedError
    return stego_uint8


def identity(x):
    """Return ``x`` unchanged (no-op transform)."""
    return x


def decode_secret(model_name, model, im, tform):
    """Recover the hidden bits from image ``im``.

    Returns a boolean ndarray, one entry per decoder output, set where the
    decoder logit is positive (presumably shape (1, secret_len) — confirm
    against the decoder).

    Raises:
        NotImplementedError: for models other than 'RoSteALS' / 'UNet'.
    """
    if model_name in ['RoSteALS', 'UNet']:
        with torch.no_grad():
            im = tform(im).unsqueeze(0).to(model.device)  # (1, 3, 224, 224) with tform_det
            secret_pred = (model.decoder(im) > 0).cpu().numpy()
    else:
        raise NotImplementedError
    return secret_pred


@st.cache_resource
def load_model(model_name, _args):
    """Load (once per process, via st.cache_resource) the model and its transforms.

    ``_args`` has a leading underscore so Streamlit does not hash it as part
    of the cache key.

    Returns:
        (model, tform_emb, tform_det, secret_len): the embedding transform
        resizes to 256x256, the detection transform to 224x224; both map
        pixels to [-1, 1].
    """
    if model_name == 'UNet':
        tform_emb = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        tform_det = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        model, secret_len = load_UNet(_args)
    else:
        raise NotImplementedError
    return model, tform_emb, tform_det, secret_len


@st.cache_resource
def load_ecc(ecc_name, secret_len):
    """Return the error-correcting code used to map text <-> ``secret_len`` bits.

    Raises:
        NotImplementedError: for an unknown ``ecc_name`` or a BCH
            ``secret_len`` with no known configuration (only 100 and 160).
    """
    if ecc_name == 'BCH':
        if secret_len == 160:
            return BCH(285, 10, secret_len, verbose=True)
        if secret_len == 100:
            return BCH(137, 5, payload_len=secret_len, verbose=True)
        raise NotImplementedError(f'No BCH configuration for secret_len={secret_len}')
    if ecc_name == 'RSC':
        # NOTE(review): secret_len is not used here; 16 + 4 bytes presumably
        # targets a 160-bit secret — confirm.
        return RSC(data_bytes=16, ecc_bytes=4, verbose=True)
    raise NotImplementedError(f'Unknown ECC: {ecc_name}')


class Resize(object):
    """Resize a PIL image or ndarray to a fixed size and return an ndarray.

    Sizes follow PIL's (width, height) convention. LANCZOS is used when the
    image's shorter side is larger than the target's shorter side
    (downsampling), BILINEAR otherwise.
    """

    def __init__(self, size=None) -> None:
        self.size = size  # default target (width, height); may be overridden per call

    def __call__(self, x, size=None):
        if isinstance(x, np.ndarray):
            x = Image.fromarray(x)
        new_size = size if size is not None else self.size
        if min(x.size) > min(new_size):  # downsample
            x = x.resize(new_size, Image.LANCZOS)
        else:  # upsample
            x = x.resize(new_size, Image.BILINEAR)
        x = np.array(x)
        return x


def parse_st_args():
    """Parse app arguments.

    Usage: ``streamlit run app.py -- --weight <ckpt or URL> --config <yaml or URL>``.
    The defaults point at the author's cluster paths and will need overriding elsewhere.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight', default='/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/checkpoints/epoch=000070-step=000219999.ckpt')
    parser.add_argument('--config', default='/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/configs/-project.yaml')
    args = parser.parse_args()
    return args


def app(args):
    """Render the demo page: upload an image, embed a text secret, verify it.

    After embedding, the stego image is re-processed (resize + center crop)
    and decoded to show the recovered text and the raw bit accuracy.
    """
    st.title('Watermarking Demo')

    # setup model
    model_name = st.selectbox("Choose the model", model_names)
    model, tform_emb, tform_det, secret_len = load_model(model_name, args)
    display_width = 300

    # ecc
    ecc = load_ecc('BCH', secret_len)

    # setup st
    st.subheader("Input")
    image_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
    if image_file is not None:
        print('Image: ', image_file.name)
        # Lower-case so "photo.JPG" is treated like "photo.jpg" (JPEG, not PNG).
        ext = image_file.name.split('.')[-1].lower()
        im = Image.open(image_file).convert('RGB')
        st.image(im, width=display_width)
    secret_text = st.text_input(f'Input the secret (max {ecc.data_len} chars)', 'A secret')
    if len(secret_text) > ecc.data_len:
        # Report in the UI rather than `assert` (stripped under -O, raw traceback).
        st.error(f'Secret is too long: {len(secret_text)} chars (max {ecc.data_len}).')
        st.stop()

    # embed
    st.subheader("Embed results")
    status = st.empty()
    # Simulate a mild resize + crop before decoding to check robustness.
    prep = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224))
    ])
    if image_file is not None and secret_text is not None:
        secret = ecc.encode_text([secret_text])  # (1, len)
        secret = torch.from_numpy(secret).float().to(model.device)
        stego = embed_secret(model_name, model, im, tform_emb, secret)
        st.image(stego, width=display_width)

        # download button
        mime = 'image/jpeg' if ext in ('jpg', 'jpeg') else f'image/{ext}'
        stego_bytes = to_bytes(stego, mime)
        st.download_button(label='Download image', data=stego_bytes, file_name=f'stego.{ext}', mime=mime)

        # verify secret
        stego_processed = prep(Image.fromarray(stego))
        secret_pred = decode_secret(model_name, model, stego_processed, tform_det)
        bit_acc = (secret_pred == secret.cpu().numpy()).mean()
        secret_pred = ecc.decode_text(secret_pred)[0]
        # st.empty() holds a single element: a second .markdown() call would
        # replace the first, so both results are rendered in one call. No HTML
        # is needed, so the decoded text is not rendered with unsafe_allow_html.
        status.markdown(
            f'**Secret recovery check:** {secret_pred}\n\n**Bit accuracy:** {bit_acc}'
        )


if __name__ == '__main__':
    args = parse_st_args()
    app(args)