File size: 4,031 Bytes
6142a25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90921aa
6142a25
 
04acf84
6142a25
 
 
90921aa
6142a25
90921aa
6142a25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04acf84
6142a25
04acf84
6142a25
04acf84
6142a25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04acf84
 
6142a25
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
streamlit app demo
how to run:
streamlit run app.py --server.port 8501

@author: Tu Bui @surrey.ac.uk
"""
import os, sys, torch 
import inspect
_this_file = os.path.abspath(inspect.getfile(inspect.currentframe()))
cdir = os.path.dirname(_this_file)  # directory that contains this script
# Make the parent directory importable so the project-local imports below
# (presumably ldm/, cldm/, tools/ live there — confirm) can be resolved.
# Index 1 keeps sys.path[0] (normally the script's own directory) first.
sys.path.insert(1, os.path.join(cdir, '../'))
import argparse
from pathlib import Path
import numpy as np
import pickle
import pytorch_lightning as pl
from torchvision import transforms
import argparse
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
from PIL import Image
from tools.augment_imagenetc import RandomImagenetC
from cldm.transformations2 import TransformNet
from io import BytesIO
from tools.helpers import welcome_message
from tools.ecc import BCH, RSC
import streamlit as st
from Embed_Secret import parse_st_args, load_ecc, load_model, decode_secret, to_bytes, model_names


def app(args):
    """Render the watermark corruption / secret-extraction demo page.

    Page workflow:
      1. choose a model and upload a stego (watermarked) image;
      2. apply the 'Fixed Augment' crop/flip/resize transform, then up to two
         optional corruptions picked from ``TransformNet.optional_names``;
      3. show the corrupted image and offer it for download;
      4. decode the embedded secret from the corrupted image;
      5. if a ground-truth secret is typed in, report bit accuracy (against
         the ECC-encoded ground truth) and word accuracy (exact text match).

    Widgets are always rendered in the same order; the image-dependent work
    only runs once a file has been uploaded.

    Args:
        args: parsed arguments from ``parse_st_args``; forwarded unchanged to
            ``load_model``.
    """
    st.title('Watermarking Demo')

    # setup model
    model_name = st.selectbox("Choose the model", model_names)
    # _tform_emb is not used by this page.
    model, _tform_emb, tform_det, secret_len = load_model(model_name, args)
    display_width = 300
    ecc = load_ecc('BCH', secret_len)
    noise = TransformNet(p=1.0, crop_mode='resized_crop')
    noise_names = noise.optional_names
    corrupt_options = ['None'] + noise_names

    # setup st
    st.subheader("Input")
    image_file = st.file_uploader("Upload stego image", type=["png", "jpg", "jpeg"])
    have_image = image_file is not None
    if have_image:
        im = Image.open(image_file).convert('RGB')
        # Lower-case so "x.JPG" gets the same MIME type / file name as "x.jpg".
        ext = image_file.name.split('.')[-1].lower()
        st.image(im, width=display_width)

    # add crop
    st.subheader("Corruptions")
    # Clicking the button makes Streamlit rerun this script, and the rerun
    # recomputes the crop below, so the button's return value is not needed.
    st.button('Regenerate Crop/Flip/Resize', key='crop')
    if have_image:
        im_crop = noise.apply_transform_on_pil_image(im, 'Fixed Augment')

    # add noise source 1
    corrupt_method1 = st.selectbox("Choose noise source #1", corrupt_options, key='noise1')
    if have_image:
        im_noise1 = (im_crop if corrupt_method1 == 'None'
                     else noise.apply_transform_on_pil_image(im_crop, corrupt_method1))

    # add noise source 2
    corrupt_method2 = st.selectbox("Choose noise source #2", corrupt_options, key='noise2')
    if have_image:
        im_noise2 = (im_noise1 if corrupt_method2 == 'None'
                     else noise.apply_transform_on_pil_image(im_noise1, corrupt_method2))

    st.subheader("Output")
    if have_image:
        st.image(im_noise2, width=display_width)
        mime = 'image/jpeg' if ext == 'jpg' else f'image/{ext}'
        im_noise2_bytes = to_bytes(np.uint8(im_noise2), mime)
        st.download_button(label='Download image', data=im_noise2_bytes,
                           file_name=f'corrupted.{ext}', mime=mime)

    # prediction
    st.subheader('Extract Secret From Output')
    status = st.empty()
    if have_image:
        secret_pred = decode_secret(model_name, model, im_noise2, tform_det)
        secret_decoded = ecc.decode_text(secret_pred)[0]
        # The decoded text originates from an uploaded (untrusted) image and
        # this message needs only Markdown, so raw HTML is not enabled.
        status.markdown(f'Predicted secret: **{secret_decoded}**')

    # bit acc
    st.subheader('Accuracy')
    secret_text = st.text_input('Input groundtruth secret')
    bit_acc_status = st.empty()
    if have_image and secret_text:
        # Originally annotated as shape (1, 100); presumably (1, secret_len)
        # so it lines up element-wise with secret_pred — TODO confirm.
        secret = ecc.encode_text([secret_text])
        bit_acc = (secret_pred == secret).mean()
        word_acc = int(secret_decoded == secret_text)
        # HTML is needed here only for the <br /> line break.
        bit_acc_status.markdown(
            f'Bit Accuracy: **{bit_acc*100:.2f}%**<br />Word Accuracy: **{word_acc}**',
            unsafe_allow_html=True)

if __name__ == '__main__':
    # Script entry point: parse the command-line arguments, then build the page.
    args = parse_st_args()
    app(args=args)