import gradio as gr
import gradio.inputs as grinputs
import gradio.outputs as groutputs

import numpy as np

import torch
import torch.nn as nn
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(0)
np.random.seed(0)

FPR = 1e-6  # target false positive rate for watermark detection
carrier = np.random.randn(1, 2048)  # random carrier direction in the 2048-d feature space


def build_backbone(path, name='resnet50'):
    """ Builds a torchvision backbone (ResNet-50 by default) and loads weights from a checkpoint. """
    model = getattr(models, name)(pretrained=True)
    # Replace any classification head with identity so the model outputs raw features.
    model.head = nn.Identity()
    model.fc = nn.Identity()
    checkpoint = torch.load(path, map_location=device)
    state_dict = checkpoint
    # Some checkpoints (e.g. DINO) nest the weights under one of these keys.
    for ckpt_key in ['state_dict', 'model_state_dict', 'teacher']:
        if ckpt_key in checkpoint:
            state_dict = checkpoint[ckpt_key]
    # Strip DataParallel and wrapper prefixes from parameter names.
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
    msg = model.load_state_dict(state_dict, strict=False)
    print(msg)  # report missing / unexpected keys
    return model

def get_linear_layer(weight, bias):
    """ Creates a layer that performs feature whitening or centering """
    dim_out, dim_in = weight.shape
    layer = nn.Linear(dim_in, dim_out)
    layer.weight = nn.Parameter(weight)
    layer.bias = nn.Parameter(bias)
    return layer

def load_normalization_layer(path):
    """
    Loads the normalization layer from a checkpoint and returns the layer.
    """
    checkpoint = torch.load(path, map_location=device)
    if 'whitening' in path or 'out' in path:
        # Rescale whitening weights and bias by the feature dimension D.
        D = checkpoint['weight'].shape[1]
        weight = torch.nn.Parameter(D * checkpoint['weight'])
        bias = torch.nn.Parameter(D * checkpoint['bias'])
    else:
        weight = checkpoint['weight']
        bias = checkpoint['bias']
    return get_linear_layer(weight, bias).to(device, non_blocking=True)

class NormLayerWrapper(nn.Module):
    """
    Wraps backbone model and normalization layer
    """
    def __init__(self, backbone, head):
        super(NormLayerWrapper, self).__init__()
        backbone.eval()
        head.eval()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        output = self.backbone(x)
        return self.head(output)
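
# Illustrative sketch (not part of the original script): how the wrapped model could be used to
# extract a normalized feature vector from an input image. The ImageNet preprocessing statistics,
# the 224x224 input size, and the helper name `extract_features` are assumptions for illustration.
from torchvision import transforms

_preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def extract_features(pil_image, feature_model):
    """Returns a (1, 2048) tensor: backbone features passed through the normalization layer."""
    x = _preprocess(pil_image).unsqueeze(0).to(device)  # (1, 3, 224, 224)
    with torch.no_grad():
        return feature_model(x)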

backbone = build_backbone(path='dino_r50.pth')
normlayer = load_normalization_layer(path='out2048.pth')
model = NormLayerWrapper(backbone, normlayer)

def encode(image):
    # Placeholder: returns the image unchanged; watermark embedding is not implemented here.
    return image

def decode(image):
    # Placeholder: returns a fixed string; watermark detection is not implemented here.
    return 'decoded'

def on_submit(image, mode):
    print('{} mode'.format(mode))
    if mode == 'Encode':
        return encode(image), 'Successfully encoded'
    else:
        return image, decode(image)
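
# Illustrative sketch (an assumption, not the original method): the otherwise unused `carrier` and
# `FPR` above suggest a carrier-based detection scheme, where decoding would project the image
# feature onto the carrier and compare the cosine similarity against a threshold chosen for the
# target false positive rate. The hypothetical helper below only computes the similarity; deriving
# the threshold from FPR is left out.
def carrier_similarity(feature):
    """Cosine similarity between a (1, 2048) feature tensor and the fixed random carrier."""
    c = torch.from_numpy(carrier).float().to(feature.device)  # (1, 2048)
    f = feature / feature.norm(dim=-1, keepdim=True)          # unit-normalize feature
    c = c / c.norm(dim=-1, keepdim=True)                      # unit-normalize carrier
    return (f * c).sum(dim=-1)                                 # (1,) cosine similarity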

iface = gr.Interface(
    fn=on_submit, 
    inputs=[
        grinputs.Image(), 
        grinputs.Radio(['Encode', 'Decode'], label="Encode or Decode mode")], 
    outputs=[
        groutputs.Image(label='Watermarked image'), 
        groutputs.Textbox(label='Information')],
    allow_screenshot=False, 
    allow_flagging="auto",
    )
iface.launch()