"""Gradio demo that classifies geometric shapes with several fine-tuned models."""

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Hugging Face Hub checkpoints offered in the model dropdown.
model_names = [
    "0-ma/swin-geometric-shapes-tiny",
    "0-ma/mobilenet-v2-geometric-shapes",
    "0-ma/focalnet-geometric-shapes-tiny",
    "0-ma/efficientnet-b2-geometric-shapes",
    "0-ma/beit-geometric-shapes-base",
    "0-ma/mit-b0-geometric-shapes",
    "0-ma/vit-geometric-shapes-base",
    "0-ma/resnet-geometric-shapes",
    "0-ma/vit-geometric-shapes-tiny",
]

# Example images shown in the UI; file names follow "<index>_<Label>.jpg".
example_images = [
    'example/1_None.jpg',
    'example/2_Circle.jpg',
    'example/3_Triangle.jpg',
    'example/4_Square.jpg',
    'example/5_Pentagone.jpg',
    'example/6_Hexagone.jpg'
]

# Class labels parsed from the example file names ("1_None.jpg" -> "None").
# NOTE(review): this assumes the order of example_images matches the
# class-index order of every model above -- confirm against each model's
# config.id2label.
labels = [example.split("_")[1].split(".")[0] for example in example_images]

# Every processor and model is loaded once, up front, so switching models in
# the UI only costs a dictionary lookup.
feature_extractors = {model_name: AutoImageProcessor.from_pretrained(model_name) for model_name in model_names}
classification_models = {model_name: AutoModelForImageClassification.from_pretrained(model_name) for model_name in model_names}


def predict(image, selected_model):
    """Classify *image* with the model named *selected_model*.

    Args:
        image: PIL image coming from the UI, or None when the input is empty.
        selected_model: a key of ``model_names``.

    Returns:
        None when *image* is None. Otherwise a dict mapping label -> score.
        Scores are computed by clipping negative logits to zero and
        normalising the remaining positive logits so they sum to 1.
        Only classes with a positive logit are included.
        The dict is empty when no logit is positive.
    """
    if image is None:
        return None

    feature_extractor = feature_extractors[selected_model]
    model = classification_models[selected_model]

    # Image processors expect 3-channel input; uploads may be RGBA (e.g. PNG)
    # or single-channel grayscale.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    inputs = feature_extractor(images=[image], return_tensors="pt")
    # Inference only: no need to build an autograd graph.
    with torch.no_grad():
        logits = model(**inputs)["logits"].cpu().numpy()[0]

    # np.clip returns a new array, so the raw logits are left untouched.
    positive = np.clip(logits, 0, None)
    total = positive.sum()
    if total <= 0:
        # No class has a positive logit; avoid dividing 0 by 0 (NaN).
        return {}

    return {
        label: float(score / total)
        for label, score in zip(labels, positive)
        if score > 0
    }


title = "Geometric Shape Classifier"
description = "Select a model and upload an image to classify geometric shapes."
with gr.Blocks() as demo:
    # Page header.
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    # Input controls side by side: model picker and image upload.
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=model_names,
            label="Select Model",
            value=model_names[0],
        )
        image_input = gr.Image(type="pil")

    # Examples are placed between the inputs and the result. Clicking one
    # loads it into image_input, and image_input's change event (wired
    # below) then runs the classifier.
    gr.Examples(
        examples=example_images,
        inputs=image_input,
        label="Click on an example image to test",
    )

    output = gr.Label(label="Classification Result")

    def classify(img, model):
        """Return predict(img, model), or None when no image is present."""
        if img is None:
            return None
        return predict(img, model)

    # Re-run classification whenever the image or the chosen model changes.
    # Order (image first, then dropdown) matches the original registration.
    for trigger in (image_input, model_dropdown):
        trigger.change(
            fn=classify,
            inputs=[image_input, model_dropdown],
            outputs=output,
        )

demo.launch()