OzoneAsai's picture
Update app.py
0218651
raw
history blame
879 Bytes
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor
# Load the NSFW image classifier and its image processor once, at import time,
# so every call to classify_image reuses them instead of reloading per request.
# A single constant keeps the model and its processor pointing at the same
# checkpoint.
MODEL_ID = "Falconsai/nsfw_image_detection"
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
processor = ViTImageProcessor.from_pretrained(MODEL_ID)
# Prediction function wired into the Gradio interface.
def classify_image(img) -> str:
    """Return the model's predicted label name for a single image.

    Args:
        img: The image supplied by the Gradio image input (assumed to be a
            NumPy array from ``inputs="image"`` -- TODO confirm), or ``None``
            when no image is present.

    Returns:
        The ``model.config.id2label`` entry for the highest-scoring class, or
        an empty string when ``img`` is ``None``.
    """
    # With live=True Gradio may invoke this without an image (e.g. after the
    # input is cleared); presumably the processor would raise on None, so
    # short-circuit with an empty result instead of surfacing an error.
    if img is None:
        return ""
    # Inference only: disable autograd tracking to save memory and time.
    with torch.no_grad():
        inputs = processor(images=img, return_tensors="pt")
        logits = model(**inputs).logits
    # `.item()` requires exactly one value, i.e. one image in the batch;
    # argmax over the last axis picks the top-scoring class index.
    predicted_label = logits.argmax(-1).item()
    return model.config.id2label[predicted_label]
# Gradio interface: one image input, the predicted label shown as text, and
# live=True so the prediction re-runs whenever the input changes.
# The former `interpretation="default"` argument is not passed: the
# interpretation feature was removed from gr.Interface in Gradio 4.0.
iface = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs="text",
    live=True,
)
# Launch the Gradio web app.
iface.launch()
#gr.load("models/Falconsai/nsfw_image_detection").launch()