ruidanwang commited on
Commit
f0e0383
1 Parent(s): feb62a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -12
app.py CHANGED
@@ -1,24 +1,28 @@
1
- # prompt: gradio image 分类 not safe for work
2
  from PIL import Image
 
3
  import gradio as gr
4
- from transformers import pipeline
5
- # Load the image classification pipeline
6
- classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
 
 
7
 
8
  # Define a function to classify the image and return the results
9
  def classify_image(img):
10
- # Convert the Gradio image input to a PIL image
11
  pil_image = Image.fromarray(img.astype('uint8'), 'RGB')
12
- # Classify the image using the pipeline
13
- results = classifier(pil_image)
14
- # Format the results for display in Gradio
15
- formatted_results = {result['label']: result['score'] for result in results}
16
- return formatted_results
 
 
17
 
18
  # Create the Gradio interface
19
  image_input = gr.inputs.Image(shape=(256, 256))
20
- label_output = gr.outputs.Label(num_top_classes=3)
21
  interface = gr.Interface(fn=classify_image, inputs=image_input, outputs=label_output)
22
 
23
  # Launch the interface
24
- interface.launch()
 
1
+ import torch
2
  from PIL import Image
3
+ from transformers import AutoModelForImageClassification, ViTImageProcessor
4
  import gradio as gr
5
+
6
+ # Load the model and processor
7
+ model_name = "Falconsai/nsfw_image_detection"
8
+ model = AutoModelForImageClassification.from_pretrained(model_name)
9
+ processor = ViTImageProcessor.from_pretrained(model_name)
10
 
11
  # Define a function to classify the image and return the results
12
  def classify_image(img):
 
13
  pil_image = Image.fromarray(img.astype('uint8'), 'RGB')
14
+ inputs = processor(images=pil_image, return_tensors="pt")
15
+ with torch.no_grad():
16
+ outputs = model(**inputs)
17
+ logits = outputs.logits
18
+ predicted_label = logits.argmax(-1).item()
19
+ label = model.config.id2label[predicted_label]
20
+ return label
21
 
22
  # Create the Gradio interface
23
  image_input = gr.inputs.Image(shape=(256, 256))
24
+ label_output = gr.outputs.Label()
25
  interface = gr.Interface(fn=classify_image, inputs=image_input, outputs=label_output)
26
 
27
  # Launch the interface
28
+ interface.launch()