itsTomLie's picture
Update app.py
b8124d1 verified
raw
history blame
850 Bytes
import functools
import os

import gradio as gr
import numpy as np
from PIL import Image
from transformers import pipeline
@functools.lru_cache(maxsize=1)
def _get_classifier():
    """Build the gender-classification pipeline once and reuse it.

    Constructing a ``transformers`` pipeline loads the model, which is far
    too expensive to repeat on every request. The pipeline is created on the
    first prediction and cached for all later calls.
    """
    return pipeline("image-classification", model="rizvandwiki/gender-classification")


def predict_image(image):
    """Classify *image* and return the top prediction.

    Args:
        image: A NumPy array (as produced by ``gr.Image(type="numpy")``),
            a path to an image file, or any object the pipeline accepts
            directly (e.g. a ``PIL.Image.Image``).

    Returns:
        A ``(label, confidence)`` tuple taken from the first (top-ranked)
        entry of the pipeline's result list.

    Raises:
        gr.Error: If no image was supplied. Gradio passes ``None`` when the
            user presses Submit without uploading anything.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype("uint8"))
    elif isinstance(image, str):
        image = Image.open(image)

    # The classifier expects 3-channel input; normalise grayscale, RGBA and
    # palette images so they don't fail during preprocessing.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")

    result = _get_classifier()(image)
    top = result[0]
    return top["label"], top["score"]
# File extensions (lower-case) accepted as example images.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif")


def list_example_images(directory="examples"):
    """Return the image files directly inside *directory*, sorted by name.

    Non-image entries (e.g. ``.gitattributes``, ``.DS_Store``) and
    subdirectories are skipped, because Gradio would otherwise try to load
    them as example images. A missing directory yields an empty list rather
    than crashing the app at import time.

    Paths are built with ``os.path.join(directory, name)``, so a relative
    *directory* yields relative paths (e.g. ``examples/a.jpg``).
    """
    try:
        names = sorted(os.listdir(directory))
    except (FileNotFoundError, NotADirectoryError):
        return []
    candidates = (
        os.path.join(directory, name)
        for name in names
        if name.lower().endswith(IMAGE_EXTENSIONS)
    )
    return [path for path in candidates if os.path.isfile(path)]


example_images = list_example_images("examples")
# Single-image UI: the uploaded image is passed to ``predict_image`` as a
# NumPy array, and its (label, confidence) result fills the two textboxes.
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="numpy", label="Upload an Image"),
    outputs=[gr.Textbox(label="Prediction"), gr.Textbox(label="Confidence")],
    examples=example_images,
)

# Start the (blocking) server only when run as a script, e.g. ``python app.py``.
# Importing this module, from tests or other tooling, then builds the interface
# without launching it.
if __name__ == "__main__":
    interface.launch()