File size: 9,603 Bytes
6142a25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04acf84
 
6142a25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e97a15c
6142a25
 
 
 
 
 
 
 
 
 
 
90921aa
6142a25
 
 
 
 
17b1745
6142a25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17b1745
6142a25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90921aa
6142a25
 
90921aa
6142a25
 
 
90921aa
6142a25
90921aa
 
 
 
6142a25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90921aa
6142a25
 
90921aa
6142a25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17b1745
6142a25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
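Streamlit demo: embed a text secret into an uploaded image and check that it can be recovered.
How to run (arguments after the bare `--` are forwarded to the app):
streamlit run app.py --server.port 8501 -- --weight WEIGHT --config CONFIG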

@author: Tu Bui @surrey.ac.uk
"""
import os, sys, torch 
import argparse
from pathlib import Path
import numpy as np
import pickle
import pytorch_lightning as pl
from torchvision import transforms
import argparse
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
from PIL import Image
from tools.augment_imagenetc import RandomImagenetC
from io import BytesIO
from tools.helpers import welcome_message
from tools.ecc import BCH, RSC

import streamlit as st
from streamlit.source_util import (
    page_icon_and_name, 
    calc_md5, 
    get_pages,
    _on_pages_changed
)

# Model architectures offered in the UI selectbox; load_model() handles each entry.
model_names = ['UNet']


def delete_page(main_script_path_str, page_name):
    """Remove the first registered Streamlit page whose name equals ``page_name``.

    Each page name examined (up to and including the match) is printed to
    stdout. Page-change listeners are notified whether or not a page matched.

    NOTE(review): relies on ``get_pages`` / ``_on_pages_changed`` from
    ``streamlit.source_util``, which are not public API and may differ across
    Streamlit versions — confirm against the pinned version. Presumably the
    dict returned by ``get_pages`` is the live registry; verify.
    """
    registry = get_pages(main_script_path_str)
    for page_hash in registry:
        name = registry[page_hash]['page_name']
        print(name)
        if name != page_name:
            continue
        # Deleting during iteration is safe only because we stop immediately.
        del registry[page_hash]
        break
    _on_pages_changed.send()


def add_page(main_script_path_str, page_name):
    """Register a script whose filename contains ``page_name`` as a Streamlit page.

    Candidates are the ``*.py`` files in the ``pages`` directory beside the main
    script, followed by the ``*.py`` files beside the main script itself; the
    first candidate whose filename contains ``page_name`` is registered.
    Raises IndexError when no candidate matches. Listeners are notified after
    the registry entry is written.

    NOTE(review): uses non-public ``streamlit.source_util`` helpers — confirm
    they exist in the pinned Streamlit version.
    """
    registry = get_pages(main_script_path_str)
    app_dir = Path(main_script_path_str).parent
    candidates = list((app_dir / "pages").glob("*.py"))
    candidates += list(app_dir.glob("*.py"))
    matches = [path for path in candidates if page_name in path.name]
    script_path = matches[0]
    resolved = str(script_path.resolve())
    icon, name = page_icon_and_name(script_path)
    page_hash = calc_md5(resolved)
    registry[page_hash] = dict(
        page_script_hash=page_hash,
        page_name=name,
        icon=icon,
        script_path=resolved,
    )
    _on_pages_changed.send()

def unormalize(x):
    """Convert a [-1, 1] image batch tensor into a uint8 NumPy batch.

    Args:
        x: tensor laid out as (B, C, H, W) with values nominally in [-1, 1].

    Returns:
        ``np.uint8`` array laid out as (B, H, W, C). Values are scaled to
        [0, 255], clamped, then truncated by the integer cast.
    """
    scaled = torch.clamp((x + 1) * 127.5, 0, 255)
    channels_last = scaled.permute(0, 2, 3, 1)
    return channels_last.cpu().numpy().astype(np.uint8)

def to_bytes(x, mime):
    """Encode an image array as in-memory JPEG or PNG bytes.

    ``mime == 'image/jpeg'`` selects JPEG; any other value falls back to PNG.
    """
    image = Image.fromarray(x)
    with BytesIO() as buf:
        image.save(buf, format="JPEG" if mime == 'image/jpeg' else "PNG")
        return buf.getvalue()


def _fetch_cached(url, dest_dir):
    """Download ``url`` into ``dest_dir`` unless a same-named file is already there.

    Returns the local path as a string. Each file is checked independently, so a
    run interrupted after one download still fetches the other next time.
    """
    dest = dest_dir / url.split('/')[-1]
    if not dest.exists():
        import wget
        print(f'Downloading {url}...')
        with st.spinner("Downloading model... this may take awhile!"):
            wget.download(url, str(dest))
    return str(dest)


def load_UNet(args):
    """Build the model described by ``args.config`` and load ``args.weight`` into it.

    Args:
        args: namespace with ``config`` (OmegaConf YAML path or http URL) and
            ``weight`` (checkpoint path or http URL).

    Returns:
        ``(model, secret_len)``: the model in eval mode on CUDA when available
        (otherwise CPU), and ``config.model.params.secret_len``.

    When ``args.weight`` is a URL, it is cached under ``./weights``. The config
    is cached there too when it is also a URL; a local config path is used as-is.
    """
    print('args: ', args)
    config_file = args.config
    weight_file = args.weight
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if weight_file.startswith('http'):  # download from url
        weight_dir = Path('./weights')
        weight_dir.mkdir(exist_ok=True)
        weight_file = _fetch_cached(weight_file, weight_dir)
        if config_file.startswith('http'):
            config_file = _fetch_cached(config_file, weight_dir)

    config = OmegaConf.load(config_file).model
    secret_len = config.params.secret_len
    print(f'Secret length: {secret_len}')
    model = instantiate_from_config(config)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(weight_file, map_location=torch.device('cpu'))
    if 'global_step' in ckpt:
        # .get: a checkpoint may record global_step without epoch.
        print(f'Global step: {ckpt["global_step"]}, epoch: {ckpt.get("epoch")}')
    # Checkpoints that nest weights under 'state_dict' (e.g. PyTorch Lightning) are unwrapped.
    state_dict = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt
    # strict=False: report key mismatches instead of failing.
    misses, ignores = model.load_state_dict(state_dict, strict=False)
    print(f'Missed keys: {misses}\nIgnore keys: {ignores}')
    model = model.to(device)
    model.eval()
    return model, secret_len

def embed_secret(model_name, model, cover, tform, secret):
    """Hide ``secret`` in ``cover`` and return the stego image at the cover's size.

    The model runs at whatever resolution ``tform`` produces. Only its residual
    (clamped stego minus model input) is resized back to the cover's (H, W)
    and added to the cover in [-1, 1] space, so the original resolution is kept.

    Args:
        model_name: must be 'UNet'; anything else raises NotImplementedError.
        model: callable ``model(batch, secret) -> (stego, _)`` with a
            ``.device`` attribute.
        cover: PIL image (RGB assumed — TODO confirm for other modes).
        tform: maps ``cover`` to a (C, H', W') tensor, presumably in [-1, 1].
        secret: secret tensor accepted by ``model``.

    Returns:
        ``np.uint8`` array of shape (H, W, C) aligned with ``np.array(cover)``.
    """
    if model_name != 'UNet':
        raise NotImplementedError
    width, height = cover.size  # PIL reports (width, height)
    with torch.no_grad():
        batch = tform(cover).unsqueeze(0).to(model.device)
        stego, _ = model(batch, secret)
        residual = stego.clamp(-1, 1) - batch
        residual = torch.nn.functional.interpolate(residual, (height, width), mode='bilinear')
        residual = residual.permute(0, 2, 3, 1).cpu().numpy()  # (1, H, W, C)
        # Same evaluation order as before: (residual + cover/127.5) - 1.
        merged = np.clip(residual[0] + np.array(cover) / 127.5 - 1., -1, 1)
        stego_uint8 = (merged * 127.5 + 127.5).astype(np.uint8)
    return stego_uint8

def identity(x):
    """Return ``x`` unchanged (a no-op transform)."""
    return x

def decode_secret(model_name, model, im, tform):
    """Recover hidden bits from ``im`` using ``model.decoder``.

    ``im`` is transformed by ``tform``, batched to size 1 and moved to
    ``model.device``. The decoder output is thresholded at 0.

    Returns:
        Boolean NumPy array with a leading batch dimension of 1.

    Raises:
        NotImplementedError: for model names other than 'RoSteALS' / 'UNet'.
    """
    if model_name not in ('RoSteALS', 'UNet'):
        raise NotImplementedError
    with torch.no_grad():
        batch = tform(im).unsqueeze(0).to(model.device)
        logits = model.decoder(batch)
        return (logits > 0).cpu().numpy()


@st.cache_resource
def load_model(model_name, _args):
    """Load ``model_name`` together with its embedding and detection transforms.

    ``_args`` is forwarded to the model loader. Streamlit does not hash
    underscore-prefixed arguments, so presumably the resource cache is keyed
    on ``model_name`` only — confirm against the pinned Streamlit version.

    Returns:
        ``(model, tform_emb, tform_det, secret_len)``. The embedding transform
        resizes to 256x256 and the detection transform to 224x224. Both then
        convert to a tensor and normalize with mean 0.5 / std 0.5.

    Raises:
        NotImplementedError: for any model other than 'UNet'.
    """
    if model_name != 'UNet':
        raise NotImplementedError

    def _pipeline(side):
        # Resize to side x side, to tensor, then (v - 0.5) / 0.5 per channel.
        return transforms.Compose([
            transforms.Resize((side, side)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    tform_emb = _pipeline(256)
    tform_det = _pipeline(224)
    model, secret_len = load_UNet(_args)
    return model, tform_emb, tform_det, secret_len


@st.cache_resource
def load_ecc(ecc_name, secret_len):
    """Build the error-correcting codec that maps secret text to embedded bits.

    Args:
        ecc_name: 'BCH' or 'RSC'.
        secret_len: number of secret bits the model embeds. BCH supports 160
            and 100. RSC ignores it and always uses ``data_bytes=16, ecc_bytes=4``.

    Returns:
        The configured ``BCH`` or ``RSC`` instance.

    Raises:
        ValueError: for an unknown ``ecc_name`` or a BCH ``secret_len`` with no
            preset. Without this check the function hit ``return ecc`` with
            ``ecc`` unbound.
    """
    if ecc_name == 'BCH':
        if secret_len == 160:
            return BCH(285, 10, secret_len, verbose=True)
        if secret_len == 100:
            return BCH(137, 5, payload_len=secret_len, verbose=True)
        raise ValueError(f'BCH supports secret_len 100 or 160, got {secret_len}')
    if ecc_name == 'RSC':
        return RSC(data_bytes=16, ecc_bytes=4, verbose=True)
    raise ValueError(f"Unknown ecc_name {ecc_name!r}; expected 'BCH' or 'RSC'")


class Resize:
    """Resize an image to a target size, picking the filter by direction.

    LANCZOS is used when the input's shorter side exceeds the target's shorter
    side (downsampling); otherwise BILINEAR. Accepts PIL images or NumPy arrays
    and always returns a NumPy array.

    Args:
        size: default target, either a ``(width, height)`` pair in PIL
            ``resize`` order or a single int for a square target.
    """

    def __init__(self, size=None) -> None:
        self.size = size

    def __call__(self, x, size=None):
        """Resize ``x`` to ``size`` (falling back to the instance default).

        Raises:
            ValueError: if neither ``size`` nor the instance default is set.
        """
        if isinstance(x, np.ndarray):
            x = Image.fromarray(x)
        new_size = size if size is not None else self.size
        if new_size is None:
            raise ValueError('Resize needs a target size: pass size= or set it in __init__')
        if isinstance(new_size, int):
            # PIL expects a (width, height) pair; treat an int as a square side.
            new_size = (new_size, new_size)
        resample = Image.LANCZOS if min(x.size) > min(new_size) else Image.BILINEAR
        return np.array(x.resize(new_size, resample))


def parse_st_args():
    """Parse the app's command-line flags.

    Usage: ``streamlit run app.py -- --weight W --config C``. The flags after
    the bare ``--`` are the ones parsed here. Both defaults are hard-coded
    absolute paths that presumably exist only on the original author's
    machine — pass real paths or URLs when running elsewhere.

    Returns:
        ``argparse.Namespace`` with ``weight`` and ``config`` attributes.
    """
    defaults = {
        '--weight': '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/checkpoints/epoch=000070-step=000219999.ckpt',
        '--config': '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/configs/-project.yaml',
    }
    parser = argparse.ArgumentParser()
    for flag, default in defaults.items():
        parser.add_argument(flag, default=default)
    return parser.parse_args()


def app(args):
    """Render the watermarking demo page.

    Flow:
    1. Pick a model and upload an image.
    2. Type a secret, which is ECC-encoded and embedded.
    3. Show the stego image and offer it for download.
    4. Decode the secret again from a resized + center-cropped copy and report
       the recovered text and the bit accuracy.

    Args:
        args: parsed CLI namespace forwarded to ``load_model``.
    """
    st.title('Watermarking Demo')
    # setup model
    model_name = st.selectbox("Choose the model", model_names)
    model, tform_emb, tform_det, secret_len = load_model(model_name, args)
    display_width = 300
    # ecc
    ecc = load_ecc('BCH', secret_len)

    # setup st
    st.subheader("Input")
    image_file = st.file_uploader("Upload an image", type=["png","jpg","jpeg"])
    if image_file is not None:
        print('Image: ', image_file.name)
        # Lower-case so "photo.JPG" yields mime image/jpeg (and a JPEG payload)
        # rather than "image/JPG", which to_bytes would encode as PNG.
        ext = image_file.name.split('.')[-1].lower()
        im = Image.open(image_file).convert('RGB')
        st.image(im, width=display_width)
    secret_text = st.text_input(f'Input the secret (max {ecc.data_len} chars)', 'A secret')
    if secret_text is not None and len(secret_text) > ecc.data_len:
        # User input: show a message instead of crashing the script on an assert.
        st.error(f'Secret is too long: {len(secret_text)} chars (max {ecc.data_len}).')
        st.stop()

    # embed
    st.subheader("Embed results")
    status = st.empty()
    if image_file is not None and secret_text is not None:
        secret = ecc.encode_text([secret_text])  # (1, len)
        secret = torch.from_numpy(secret).float().to(model.device)
        stego = embed_secret(model_name, model, im, tform_emb, secret)
        st.image(stego, width=display_width)

        # download button
        mime = 'image/jpeg' if ext == 'jpg' else f'image/{ext}'
        stego_bytes = to_bytes(stego, mime)
        st.download_button(label='Download image', data=stego_bytes, file_name=f'stego.{ext}', mime=mime)

        # verify secret on a resized + center-cropped copy of the stego image
        prep = transforms.Compose([
            transforms.Resize((256,256)),
            transforms.CenterCrop((224,224))
        ])
        stego_processed = prep(Image.fromarray(stego))
        secret_pred = decode_secret(model_name, model, stego_processed, tform_det)
        bit_acc = (secret_pred == secret.cpu().numpy()).mean()
        secret_pred = ecc.decode_text(secret_pred)[0]
        # st.empty() holds a single element: a second .markdown() call would
        # replace the first, so both results go into one call. No HTML is
        # needed, so the decoded text is not rendered with unsafe_allow_html.
        status.markdown(
            '**Secret recovery check:** ' + secret_pred
            + '\n\n**Bit accuracy:** ' + str(bit_acc)
        )

if __name__ == '__main__':
    # Flags come from `streamlit run app.py -- --weight W --config C`.
    app(parse_st_args())