# stable-signature-bzh / detect_torchscript.py
# Vivien Chappelier - "detector" (91f4aea), 1.25 kB
import torch
import torchvision.transforms as transforms
from PIL import Image
import sys
import numpy as np
from scipy.special import betainc
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Command line: <image path> <integer key> <TorchScript decoder path>.
# Fail with a usage message instead of a bare IndexError traceback.
if len(sys.argv) < 4:
    sys.exit(f"usage: {sys.argv[0]} IMAGE_PATH KEY MSG_DECODER_PATH")
img_path = sys.argv[1]
# Integer key; this script uses it as the seed of the reference-message RNG.
key = int(sys.argv[2])
msg_decoder_path = sys.argv[3]
# Preprocessing: ToTensor maps the PIL image to a float tensor in [0, 1],
# then Normalize applies (x - 0.5) / 0.5 per channel, i.e. rescales to [-1, 1].
# NOTE: ImageNet statistics (mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225]) are not used here.
transform_imnet = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Load the image given on the command line as RGB and resize it to 256x256
# (bicubic). The `with` closes the source file; convert() returns an
# in-memory copy, so `rgb` stays valid afterwards.
with Image.open(img_path) as src:
    rgb = src.convert("RGB").resize((256, 256), Image.BICUBIC)
# Add a leading batch dimension of size 1 and move to the compute device.
img = transform_imnet(rgb).unsqueeze(0).to(device)
# Debug output: value range and shape of the preprocessed input tensor.
for label, value in (
    ("img.min", img.min()),
    ("img.max", img.max()),
    ("img.shape", img.shape),
):
    print(label, value)
# Load the TorchScript decoder straight onto `device`. map_location remaps
# the serialized tensors, so a model saved on a GPU also loads on a
# CPU-only machine (a plain load followed by .to() would fail there).
msg_decoder = torch.jit.load(msg_decoder_path, map_location=device)
msg_decoder.eval()
# Inference only: no autograd graph is needed.
with torch.no_grad():
    # NOTE(review): assumes the decoder returns a (batch, dim) tensor, so [0]
    # selects the single image in the batch — TODO confirm.
    dec = msg_decoder(img)[0].cpu().numpy()
print("dec = ", dec.shape)

# Reference direction: a Gaussian vector seeded by `key`, scaled to unit
# length. Its length follows the decoder output instead of a fixed 256
# (same vector as before whenever the decoder emits 256 values).
dim = dec.shape[-1]
msg = np.random.default_rng(seed=key).standard_normal(dim)
msg = msg / np.sqrt(np.dot(msg, msg))
print("dec.dec", dec.dot(dec))
print("msg.msg", msg.dot(msg))
print("dec.msg", dec.dot(msg))

# Cosine of the angle between the decoded vector and the reference.
# dec is not normalized anywhere, so divide by its norm. Clip to [-1, 1] so
# float round-off cannot make 1 - cos^2 negative (betainc would return NaN).
# A zero vector carries no direction: treat it as cos = 0 (pfa = 1).
dec_norm = np.linalg.norm(dec)
if dec_norm > 0:
    cos_angle = float(np.clip(dec.dot(msg) / dec_norm, -1.0, 1.0))
else:
    cos_angle = 0.0

# Assuming that, without a watermark, dec's direction is independent of the
# uniformly random unit vector msg, cos^2 ~ Beta(1/2, (dim - 1) / 2), hence
# P(|cos| >= |cos_angle|) = I_{1 - cos_angle^2}((dim - 1) / 2, 1 / 2),
# where I is the regularized incomplete beta function (scipy's betainc).
# This is the false-alarm probability at the observed score (two-sided).
pfa = betainc((dim - 1) * 0.5, 0.5, 1 - cos_angle * cos_angle)
print("pfa = ", pfa)