# Hugging Face Spaces status: Runtime error
import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from transformers import BeitFeatureExtractor, BeitForImageClassification
# Load the pre-trained BEiT model and its matching preprocessor once at import
# time; classify_image reuses these module-level objects on every call.
# A single constant keeps the preprocessor and the model weights in sync.
MODEL_NAME = "microsoft/beit-large-patch16-512"
# NOTE(review): newer transformers releases deprecate BeitFeatureExtractor in
# favour of BeitImageProcessor — consider migrating once the pinned version allows.
feature_extractor = BeitFeatureExtractor.from_pretrained(MODEL_NAME)
model = BeitForImageClassification.from_pretrained(MODEL_NAME)
def classify_image(input_image):
    """Classify an image with the BEiT model and return the top-1 label.

    Args:
        input_image: Array delivered by the Gradio ``Image`` component with
            ``type="numpy"`` (presumably H x W x C; confirm against the
            component's ``image_mode``), or ``None`` when the input is cleared.

    Returns:
        ``{"Predicted Class": label}`` where ``label`` is looked up in
        ``model.config.id2label`` for the highest-scoring logit, or an empty
        dict when no image was supplied.
    """
    # With live=True the interface also fires when the image is cleared, in
    # which case no array is passed; calling .astype on None would raise.
    if input_image is None:
        return {}
    # convert("RGB") normalises grayscale ("L") or RGBA uploads to the three
    # channels the preprocessor's per-channel normalisation expects.
    image = Image.fromarray(input_image.astype("uint8")).convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return {"Predicted Class": model.config.id2label[predicted_class_idx]}
# Web UI wrapping classify_image; live=True re-runs the classifier whenever
# the input image changes.
iface = gr.Interface(
    fn=classify_image,
    # The gr.inputs namespace was removed in Gradio 4; the top-level
    # gr.Image component is the supported replacement (also present in 3.x).
    inputs=gr.Image(type="numpy"),  # deliver the upload as a numpy array
    outputs="json",
    live=True,
    title="BEiT Classification",
    # The app returns a class label, not a free-text description.
    description="Upload an image to get its predicted class label.",
)
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    iface.launch()