Vivien Chappelier commited on
Commit
71f9973
1 Parent(s): b747147

add second class in export to make it compatible with inference api

Browse files
Files changed (2) hide show
  1. calibration.safetensors +1 -1
  2. export_detector.py +20 -6
calibration.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f0eac4b0be8eb96b2a6fc7596727954ebe600951b3c4719b456948200d75f58e
3
  size 1999934
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ebfb5a2481c315bca61b2348b0268d621a351d48cf5e07785d45d5877f67ebf3
3
  size 1999934
export_detector.py CHANGED
@@ -10,9 +10,6 @@ from PIL import Image
10
  # read logits file
11
  data=(np.asarray([float(x) for x in open(sys.argv[1]).readlines()]))
12
 
13
- # negate for consistency with "1" = "watermarked"
14
- data = -data
15
-
16
  # sort and convert to safetensors format
17
  data = np.sort(data)
18
  data_min = data.min()
@@ -47,12 +44,29 @@ image_processor = BlipImageProcessor(do_resize=True,
47
 
48
  detector = AutoModelForImageClassification.from_pretrained(detector_path)
49
 
50
- # make it output "1" for "watermarked"
51
  detector.eval()
52
  with torch.no_grad():
53
- detector.classifier[1].weight.copy_(-detector.classifier[1].weight)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- #detector.push_to_hub("imatag/stable-signature-bzh-detector-resnet18")
56
  #image_processor.push_to_hub("imatag/stable-signature-bzh-detector-resnet18")
57
  examples = ['examples/not_watermarked.png', 'examples/watermarked.png']
58
 
 
10
  # read logits file
11
  data=(np.asarray([float(x) for x in open(sys.argv[1]).readlines()]))
12
 
 
 
 
13
  # sort and convert to safetensors format
14
  data = np.sort(data)
15
  data_min = data.min()
 
44
 
45
  detector = AutoModelForImageClassification.from_pretrained(detector_path)
46
 
47
+ # make it output 2 labels, with label "0" for "watermarked"
48
  detector.eval()
49
  with torch.no_grad():
50
+ w0 = detector.classifier[1].weight
51
+ b0 = detector.classifier[1].bias
52
+ fdim = w0.shape[1]
53
+ w = torch.nn.Parameter(torch.zeros((2, fdim), dtype=w0.dtype))
54
+ w[0, :] = w0
55
+ w[1, :] = -w0
56
+ detector.classifier[1].weight = w
57
+ b = torch.nn.Parameter(torch.zeros((2,), dtype=b0.dtype))
58
+ b[0] = b0
59
+ b[1] = -b0
60
+ detector.classifier[1].bias = b
61
+ labels = ["no watermark detected", "watermarked"]
62
+ label2id, id2label = dict(), dict()
63
+ for i, label in enumerate(labels):
64
+ label2id[label] = str(i)
65
+ id2label[str(i)] = label
66
+ detector.config.id2label=id2label
67
+ detector.config.label2id=label2id
68
 
69
+ detector.push_to_hub("imatag/stable-signature-bzh-detector-resnet18")
70
  #image_processor.push_to_hub("imatag/stable-signature-bzh-detector-resnet18")
71
  examples = ['examples/not_watermarked.png', 'examples/watermarked.png']
72