karwanjiru commited on
Commit
66b2cfe
·
1 Parent(s): 5c8ab1d

image moderation

Browse files
Files changed (1) hide show
  1. app.py +57 -14
app.py CHANGED
@@ -1,13 +1,46 @@
 
 
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
  from diffusers import DiffusionPipeline
5
- import torch
6
  from huggingface_hub import InferenceClient
7
  import requests
8
- from PIL import Image
9
  from io import BytesIO
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # Device configuration
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
@@ -24,6 +57,9 @@ else:
24
  MAX_SEED = np.iinfo(np.int32).max
25
  MAX_IMAGE_SIZE = 1024
26
 
 
 
 
27
  # Inference function for generating images
28
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
29
  if randomize_seed:
@@ -43,16 +79,6 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
43
 
44
  return image
45
 
46
- # Examples for the text-to-image generation
47
- examples = [
48
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
49
- "An astronaut riding a green horse",
50
- "A delicious ceviche cheesecake slice",
51
- ]
52
-
53
- # Initialize the InferenceClient
54
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
55
-
56
  # Respond function for the chatbot
57
  def respond(message, history, system_message, max_tokens, temperature, top_p):
58
  messages = [{"role": "system", "content": system_message}]
@@ -186,5 +212,22 @@ with gr.Blocks(css=css) as demo:
186
  image_moderation_result = gr.Textbox(label="Image Moderation Result")
187
  moderate_image_button.click(moderate_image, uploaded_image, image_moderation_result)
188
 
189
- if __name__ == "__main__":
190
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import torch
4
+ from torchvision import transforms
5
+ from transformers import AutoProcessor, FocalNetForImageClassification
6
  import gradio as gr
7
  import numpy as np
8
  import random
9
  from diffusers import DiffusionPipeline
 
10
  from huggingface_hub import InferenceClient
11
  import requests
 
12
  from io import BytesIO
13
 
14
+ # Paths and model setup
15
+ image_folder = "path_to_your_image_folder" # Specify the path to your image folder
16
+ model_path = "MichalMlodawski/nsfw-image-detection-large"
17
+
18
+ # List of jpg files in the folder
19
+ jpg_files = [file for file in os.listdir(image_folder) if file.lower().endswith(".jpg")]
20
+
21
+ if not jpg_files:
22
+ print("🚫 No jpg files found in folder:", image_folder)
23
+ exit()
24
+
25
+ # Load the model and feature extractor
26
+ feature_extractor = AutoProcessor.from_pretrained(model_path)
27
+ model = FocalNetForImageClassification.from_pretrained(model_path)
28
+ model.eval()
29
+
30
+ # Image transformations
31
+ transform = transforms.Compose([
32
+ transforms.Resize((512, 512)),
33
+ transforms.ToTensor(),
34
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
35
+ ])
36
+
37
+ # Mapping from model labels to NSFW categories
38
+ label_to_category = {
39
+ "LABEL_0": "Safe",
40
+ "LABEL_1": "Questionable",
41
+ "LABEL_2": "Unsafe"
42
+ }
43
+
44
  # Device configuration
45
  device = "cuda" if torch.cuda.is_available() else "cpu"
46
 
 
57
  MAX_SEED = np.iinfo(np.int32).max
58
  MAX_IMAGE_SIZE = 1024
59
 
60
+ # Initialize the InferenceClient
61
+ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
62
+
63
  # Inference function for generating images
64
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
65
  if randomize_seed:
 
79
 
80
  return image
81
 
 
 
 
 
 
 
 
 
 
 
82
  # Respond function for the chatbot
83
  def respond(message, history, system_message, max_tokens, temperature, top_p):
84
  messages = [{"role": "system", "content": system_message}]
 
212
  image_moderation_result = gr.Textbox(label="Image Moderation Result")
213
  moderate_image_button.click(moderate_image, uploaded_image, image_moderation_result)
214
 
215
+ with gr.TabItem("NSFW Classification"):
216
+ selected_image = gr.Image(type="pil", label="Upload Image for NSFW Classification")
217
+ classify_button = gr.Button("Classify Image")
218
+ classification_result = gr.Textbox(label="Classification Result")
219
+
220
+ def classify_nsfw(image):
221
+ image_tensor = transform(image).unsqueeze(0)
222
+ inputs = feature_extractor(images=image, return_tensors="pt")
223
+ with torch.no_grad():
224
+ outputs = model(**inputs)
225
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
226
+ confidence, predicted = torch.max(probabilities, 1)
227
+ label = model.config.id2label[predicted.item()]
228
+ category = label_to_category.get(label, "Unknown")
229
+ return f"Label: {label}, Category: {category}, Confidence: {confidence.item() * 100:.2f}%"
230
+
231
+ classify_button.click(classify_nsfw, selected_image, classification_result)
232
+
233
+ demo.launch()