import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from PIL import Image
from datasets import load_dataset

# Load the fine-tuned CLIP model and its processor
processor = AutoProcessor.from_pretrained("DGurgurov/clip-vit-base-patch32-oxford-pets")
model = AutoModelForZeroShotImageClassification.from_pretrained("DGurgurov/clip-vit-base-patch32-oxford-pets")

# Load the dataset only to recover the set of class labels
dataset = load_dataset("pcuenq/oxford-pets")  # adjust dataset loading as per your setup
labels = sorted(set(dataset["train"]["label"]))  # sort for a deterministic label order
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for label, i in label2id.items()}

# Classify an image by scoring it against every candidate label with CLIP
def classify_image(image):
    # Gradio passes the uploaded image as a numpy array; convert it to PIL
    image = Image.fromarray(image)
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)

    # Run inference without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image has shape [1, num_labels]; softmax across the labels
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image[0].softmax(dim=0)

    # Get the predicted label id and map it back to its name
    predicted_label_id = probs.argmax().item()
    predicted_label = id2label[predicted_label_id]

    return predicted_label

# Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(label="Upload a picture of an animal"),
    outputs=gr.Textbox(label="Predicted Animal"),
    title="Animal Classifier",
    description="CLIP-based model fine-tuned on the Oxford Pets dataset to classify animals.",
)

# Launch the Gradio interface
iface.launch()
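
# Optional sanity check: a minimal sketch of calling classify_image() directly,
# without the web UI. Left commented out because iface.launch() above blocks, so
# this would only run if placed before it. "sample.jpg" is a hypothetical local
# image path, not part of the original app.
#
# import numpy as np
# test_image = np.array(Image.open("sample.jpg").convert("RGB"))
# print(classify_image(test_image))  # prints the predicted pet breed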