import random

import gradio as gr
import numpy as np  # NOTE(review): not referenced below; kept in case something relies on it
import torch
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel

# Load the pre-trained CLIP model and its matching processor.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load the Unsplash dataset; all 25K images are in the "train" split.
dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")

height = 250  # height (px) requested from Unsplash and used for the image widget


def predict(image, labels):
    """Score *image* against each label with CLIP zero-shot classification.

    Args:
        image: the value delivered by the ``gr.Image`` component, or ``None``
            when no image has been loaded yet.
        labels: list of label strings to compare the image against.

    Returns:
        A dict mapping each label to its softmax probability (for ``gr.Label``),
        or ``None`` when there is no image or no labels to classify against.
    """
    # Without this guard the processor/model raise on a missing image or an
    # empty text batch (e.g. "Get Random Image" clicked before any labels).
    if image is None or not labels:
        return None
    inputs = processor(
        text=[f"a photo of {c}" for c in labels],
        images=image,
        return_tensors="pt",
        padding=True,
    )
    # Inference only: don't build the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image  # image-text similarity scores
    probs = logits_per_image.softmax(dim=1)  # softmax over labels -> label probabilities
    return {k: float(v) for k, v in zip(labels, probs[0])}


def rand_image():
    """Return the URL of a randomly chosen dataset photo, resized to ``height``.

    Unsplash allows dynamic requests, including the size of the image, via the
    ``h`` query parameter.
    """
    n = dataset.num_rows
    r = random.randrange(0, n)
    url = dataset[r]["photo_image_url"]
    # Append to an existing query string instead of starting a second one.
    sep = "&" if "?" in url else "?"
    return f"{url}{sep}h={height}"


def set_labels(text):
    """Parse comma-separated *text* into a list of non-empty, stripped labels.

    ``"spring, summer,"`` yields ``["spring", "summer"]``: surrounding
    whitespace is removed and empty entries (blank input, ``",,"``, a trailing
    comma) are dropped.
    """
    return [label.strip() for label in (text or "").split(",") if label.strip()]


with gr.Blocks() as demo:
    # Hidden component that stores a value and can be used as input/output;
    # here it holds the parsed label list, initially empty.
    labels = gr.State([])

    instructions = """## Instructions:
1. Enter list of labels separated by commas (or select one of the examples below)
2. Click **Get Random Image** to grab a random image from dataset and analyze it against the labels
3. Click **Re-Classify Image** to re-run classification on current image after changing labels"""
    gr.Markdown(instructions)

    with gr.Row(variant="compact"):
        # NOTE(review): `.style()` is the Gradio 3.x API (removed in 4.x); kept for the version this app targets.
        label_text = gr.Textbox(show_label=False, placeholder="Enter classification labels").style(container=False)
    gr.Examples(
        [
            "spring, summer, fall, winter",
            "mountain, city, beach, ocean, desert, forest, valley",
            "red, blue, green, white, black, purple, brown",
            "person, animal, landscape, something else",
            "day, night, dawn, dusk",
        ],
        inputs=label_text,
    )

    with gr.Row():
        with gr.Column(variant="panel"):
            im = gr.Image(interactive=False).style(height=height)
            with gr.Row():
                get_btn = gr.Button("Get Random Image").style(full_width=False)
                reclass_btn = gr.Button("Re-Classify Image").style(full_width=False)
        cf = gr.Label()

    # Re-parse the label list on every edit, on blur, and on Enter, so the
    # State is fully up to date before classification runs.
    label_text.change(fn=set_labels, inputs=label_text, outputs=labels)
    label_text.blur(fn=set_labels, inputs=label_text, outputs=labels)
    label_text.submit(fn=set_labels, inputs=label_text, outputs=labels)

    get_btn.click(fn=rand_image, outputs=im)
    # Classify automatically whenever a new image is shown, or on demand.
    im.change(predict, inputs=[im, labels], outputs=cf)
    reclass_btn.click(predict, inputs=[im, labels], outputs=cf)

if __name__ == "__main__":
    demo.launch()