yasserrmd commited on
Commit
a960bc2
1 Parent(s): 23dca80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -2
app.py CHANGED
@@ -9,7 +9,10 @@ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=
9
  pipe.load_lora_weights("Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch", weight_name="FLUX-dev-lora-children-simple-sketch.safetensors")
10
  pipe.fuse_lora(lora_scale=1.5)
11
  pipe.to("cuda")
12
- classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
 
 
 
13
 
14
  # Define the function to generate the sketch
15
  @spaces.GPU
@@ -18,7 +21,18 @@ def generate_sketch(prompt, num_inference_steps, guidance_scale):
18
  num_inference_steps=num_inference_steps,
19
  guidance_scale=guidance_scale,
20
  ).images[0]
21
- print(classifier(image))
 
 
 
 
 
 
 
 
 
 
 
22
  image_path = "generated_sketch.png"
23
 
24
  image.save(image_path)
 
9
  pipe.load_lora_weights("Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch", weight_name="FLUX-dev-lora-children-simple-sketch.safetensors")
10
  pipe.fuse_lora(lora_scale=1.5)
11
  pipe.to("cuda")
12
+
13
+ # Load the NSFW classifier
14
+ classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection",device=torch.device('cuda'))
15
+ NSFW_THRESHOLD = 0.85
16
 
17
  # Define the function to generate the sketch
18
  @spaces.GPU
 
21
  num_inference_steps=num_inference_steps,
22
  guidance_scale=guidance_scale,
23
  ).images[0]
24
+
25
+
26
+ # Classify the image for NSFW content
27
+ classification = classifier(image)
28
+
29
+ print(classification)
30
+
31
+ # Check the classification results
32
+ for result in classification:
33
+ if result['label'] == 'nsfw' and result['score'] > NSFW_THRESHOLD:
34
+ return "Inappropriate content detected. Please try another prompt."
35
+
36
  image_path = "generated_sketch.png"
37
 
38
  image.save(image_path)