test / pages /Extract_Secret.py
Tu Bui
add 160bit support
90921aa
#!/usr/bin/env python
# -*- coding: utf-8 -*-
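"""
Streamlit demo page: extract the embedded secret from a (possibly corrupted)
stego image and report bit/word accuracy against a user-supplied ground truth.
How to run (this file is a page of the multipage app launched via app.py):
streamlit run app.py --server.port 8501
@author: Tu Bui @surrey.ac.uk
"""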
import os, sys, torch
import inspect
cdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
sys.path.insert(1, os.path.join(cdir, '../'))
import argparse
from pathlib import Path
import numpy as np
import pickle
import pytorch_lightning as pl
from torchvision import transforms
import argparse
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
from PIL import Image
from tools.augment_imagenetc import RandomImagenetC
from cldm.transformations2 import TransformNet
from io import BytesIO
from tools.helpers import welcome_message
from tools.ecc import BCH, RSC
import streamlit as st
from Embed_Secret import parse_st_args, load_ecc, load_model, decode_secret, to_bytes, model_names
def app(args):
st.title('Watermarking Demo')
# setup model
model_name = st.selectbox("Choose the model", model_names)
model, tform_emb, tform_det, secret_len = load_model(model_name, args)
display_width = 300
ecc = load_ecc('BCH', secret_len)
noise = TransformNet(p=1.0, crop_mode='resized_crop')
noise_names = noise.optional_names
# setup st
st.subheader("Input")
image_file = None
image_file = st.file_uploader("Upload stego image", type=["png","jpg","jpeg"])
if image_file is not None:
im = Image.open(image_file).convert('RGB')
ext = image_file.name.split('.')[-1]
st.image(im, width=display_width)
# add crop
st.subheader("Corruptions")
crop_button = st.button('Regenerate Crop/Flip/Resize', key='crop')
if image_file is not None:
im_crop = noise.apply_transform_on_pil_image(im, 'Fixed Augment')
if crop_button:
im_crop = noise.apply_transform_on_pil_image(im, 'Fixed Augment')
# st.image(im_crop, width=display_width)
# add noise source 1
corrupt_method1 = st.selectbox("Choose noise source #1", ['None'] + noise_names, key='noise1')
if image_file is not None:
if corrupt_method1=='None':
im_noise1 = im_crop
else:
im_noise1 = noise.apply_transform_on_pil_image(im_crop, corrupt_method1)
# st.image(im_noise1, width=display_width)
# add noise source 2
corrupt_method2 = st.selectbox("Choose noise source #2", ['None'] + noise_names, key='noise2')
if image_file is not None:
if corrupt_method2=='None':
im_noise2 = im_noise1
else:
im_noise2 = noise.apply_transform_on_pil_image(im_noise1, corrupt_method2)
st.subheader("Output")
if image_file is not None:
st.image(im_noise2, width=display_width)
mime='image/jpeg' if ext=='jpg' else f'image/{ext}'
im_noise2_bytes = to_bytes(np.uint8(im_noise2), mime)
st.download_button(label='Download image', data=im_noise2_bytes, file_name=f'corrupted.{ext}', mime=mime)
# prediction
st.subheader('Extract Secret From Output')
status = st.empty()
if image_file is not None:
secret_pred = decode_secret(model_name, model, im_noise2, tform_det)
secret_decoded = ecc.decode_text(secret_pred)[0]
status.markdown(f'Predicted secret: **{secret_decoded}**', unsafe_allow_html=True)
# bit acc
st.subheader('Accuracy')
secret_text = st.text_input('Input groundtruth secret')
bit_acc_status = st.empty()
if image_file is not None and secret_text:
secret = ecc.encode_text([secret_text]) # (1, 100)
bit_acc = (secret_pred == secret).mean()
# bit_acc_status.markdown('**Bit Accuracy**: {:.2f}%'.format(bit_acc*100), unsafe_allow_html=True)
word_acc = int(secret_decoded == secret_text)
bit_acc_status.markdown(f'Bit Accuracy: **{bit_acc*100:.2f}%**<br />Word Accuracy: **{word_acc}**', unsafe_allow_html=True)
if __name__ == '__main__':
    # Entry point: parse the command-line options, then render the page.
    app(parse_st_args())