rynmurdock committed on
Commit
5c43323
1 Parent(s): 32bd2b2

added safety checker

Browse files
app.py CHANGED
@@ -27,6 +27,8 @@ from transformers import CLIPVisionModelWithProjection
27
  from huggingface_hub import hf_hub_download
28
  from safetensors.torch import load_file
29
 
 
 
30
  prompt_list = [p for p in list(set(
31
  pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
32
 
@@ -56,6 +58,8 @@ pipe.to(device=DEVICE)
56
  # TODO put back
57
  @spaces.GPU
58
  def compile_em():
 
 
59
  pipe.unet = torch.compile(pipe.unet)
60
  pipe.vae = torch.compile(pipe.vae, mode='reduce-overhead')
61
  autoencoder.model.forward = torch.compile(autoencoder.model.forward, backend='inductor', dynamic=True)
@@ -160,6 +164,11 @@ def predict(
160
  im_emb, _ = pipe.encode_image(
161
  image, DEVICE, 1, output_hidden_state
162
  )
 
 
 
 
 
163
  return image, im_emb.to('cpu')
164
 
165
 
@@ -245,10 +254,10 @@ def next_image(embs, img_embs, ys, calibrate_prompts):
245
  image, img_emb = predict(prompt, im_emb=img_emb)
246
  img_embs.append(img_emb)
247
 
248
- #if len(embs) > 20:
249
- # embs.pop(0)
250
- # img_embs.pop(0)
251
- # ys.pop(0)
252
  return image, embs, img_embs, ys, calibrate_prompts
253
 
254
 
@@ -274,7 +283,7 @@ def start(_, embs, img_embs, ys, calibrate_prompts):
274
  ]
275
 
276
 
277
- def choose(choice, embs, img_embs, ys, calibrate_prompts):
278
  if choice == 'Like (L)':
279
  choice = 1
280
  elif choice == 'Neither (Space)':
@@ -284,6 +293,12 @@ def choose(choice, embs, img_embs, ys, calibrate_prompts):
284
  return img, embs, img_embs, ys, calibrate_prompts
285
  else:
286
  choice = 0
 
 
 
 
 
 
287
  ys.append(choice)
288
  img, embs, img_embs, ys, calibrate_prompts = next_image(embs, img_embs, ys, calibrate_prompts)
289
  return img, embs, img_embs, ys, calibrate_prompts
@@ -363,17 +378,17 @@ with gr.Blocks(css=css, head=js_head) as demo:
363
  b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
364
  b1.click(
365
  choose,
366
- [b1, embs, img_embs, ys, calibrate_prompts],
367
  [img, embs, img_embs, ys, calibrate_prompts]
368
  )
369
  b2.click(
370
  choose,
371
- [b2, embs, img_embs, ys, calibrate_prompts],
372
  [img, embs, img_embs, ys, calibrate_prompts]
373
  )
374
  b3.click(
375
  choose,
376
- [b3, embs, img_embs, ys, calibrate_prompts],
377
  [img, embs, img_embs, ys, calibrate_prompts]
378
  )
379
  with gr.Row():
 
27
  from huggingface_hub import hf_hub_download
28
  from safetensors.torch import load_file
29
 
30
+ from safety_checker_improved import maybe_nsfw
31
+
32
  prompt_list = [p for p in list(set(
33
  pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
34
 
 
58
  # TODO put back
59
  @spaces.GPU
60
  def compile_em():
61
+ # TODO Compile
62
+ return None
63
  pipe.unet = torch.compile(pipe.unet)
64
  pipe.vae = torch.compile(pipe.vae, mode='reduce-overhead')
65
  autoencoder.model.forward = torch.compile(autoencoder.model.forward, backend='inductor', dynamic=True)
 
164
  im_emb, _ = pipe.encode_image(
165
  image, DEVICE, 1, output_hidden_state
166
  )
167
+
168
+ nsfw = maybe_nsfw(image)
169
+ if nsfw:
170
+ return None, im_emb.to('cpu')
171
+
172
  return image, im_emb.to('cpu')
173
 
174
 
 
254
  image, img_emb = predict(prompt, im_emb=img_emb)
255
  img_embs.append(img_emb)
256
 
257
+ if len(embs) > 100:
258
+ embs.pop(0)
259
+ img_embs.pop(0)
260
+ ys.pop(0)
261
  return image, embs, img_embs, ys, calibrate_prompts
262
 
263
 
 
283
  ]
284
 
285
 
286
+ def choose(img, choice, embs, img_embs, ys, calibrate_prompts):
287
  if choice == 'Like (L)':
288
  choice = 1
289
  elif choice == 'Neither (Space)':
 
293
  return img, embs, img_embs, ys, calibrate_prompts
294
  else:
295
  choice = 0
296
+
297
+ print(img, 'img')
298
+ if img is None:
299
+ print('NSFW -- choice is disliked')
300
+ choice = 0
301
+
302
  ys.append(choice)
303
  img, embs, img_embs, ys, calibrate_prompts = next_image(embs, img_embs, ys, calibrate_prompts)
304
  return img, embs, img_embs, ys, calibrate_prompts
 
378
  b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
379
  b1.click(
380
  choose,
381
+ [img, b1, embs, img_embs, ys, calibrate_prompts],
382
  [img, embs, img_embs, ys, calibrate_prompts]
383
  )
384
  b2.click(
385
  choose,
386
+ [img, b2, embs, img_embs, ys, calibrate_prompts],
387
  [img, embs, img_embs, ys, calibrate_prompts]
388
  )
389
  b3.click(
390
  choose,
391
+ [img, b3, embs, img_embs, ys, calibrate_prompts],
392
  [img, embs, img_embs, ys, calibrate_prompts]
393
  )
394
  with gr.Row():
nsfweffnetv2-b02-3epochs.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91422f388d1632c1af21b3d787b4f6c1a8e6114f600162d392b0bf285ff8a433
3
+ size 71027272
safety_checker_improved.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # TODO required tensorflow==2.14 for me
3
+ # weights from https://github.com/LAION-AI/safety-pipeline/tree/main
4
+ from PIL import Image
5
+ import tensorflow_hub as hub
6
+ import tensorflow
7
+ import numpy as np
8
+ import sys
9
+ sys.path.append('/home/ryn_mote/Misc/generative_recommender/gradio_video/automl/efficientnetv2/')
10
+ import tensorflow as tf
11
+ from tensorflow.keras import mixed_precision
12
+ physical_devices = tf.config.list_physical_devices('GPU')
13
+
14
+ tf.config.experimental.set_memory_growth(
15
+ physical_devices[0], True
16
+ )
17
+
18
+ model = tf.keras.models.load_model('nsfweffnetv2-b02-3epochs.h5',custom_objects={"KerasLayer":hub.KerasLayer})
19
+ # "The image classifier had been trained on 682550 images from the 5 classes "Drawing" (39026), "Hentai" (28134), "Neutral" (369507), "Porn" (207969) & "Sexy" (37914).
20
+ # ... we created a manually inspected test set that consists of 4900 samples, that contains images & their captions."
21
+
22
+ # Run prediction
23
+ def maybe_nsfw(pil_image):
24
+ # Run prediction
25
+ imm = tensorflow.image.resize(np.array(pil_image)[:, :, :3], (260, 260))
26
+ imm = (imm / 255)
27
+ pred = model(tensorflow.expand_dims(imm, 0)).numpy()
28
+ probs = tensorflow.math.softmax(pred[0]).numpy()
29
+ print(probs)
30
+ if all([i < .3 for i in probs[[1, 3, 4]]]):
31
+ return False
32
+ return True
33
+
34
+ # pre-initializing prediction
35
+ maybe_nsfw(Image. new("RGB", (260, 260), 255))
36
+ model.load_weights('nsfweffnetv2-b02-3epochs.h5', by_name=True, )
37
+
38
+
39
+
40
+
41
+
42
+
43
+
44
+
45
+