"""Gradio demo with three tabs: zero-shot classification, captioning, image search.

Zero-shot classification runs a local CLIP model against random photos from an
Unsplash dataset; captioning and image search are delegated to other Spaces
loaded through ``gr.load``.
"""

import gradio as gr
from datasets import load_dataset
import random
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from os import environ

# Load the pre-trained model and processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load the Unsplash dataset
dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")  # all 25K images are in train split

height = 256  # height for resizing images


def predict(image, labels):
    """Score ``image`` against each label with CLIP.

    Args:
        image: The image passed to the CLIP processor as ``images``.
        labels: List of label strings; each is wrapped as "a photo of <label>".

    Returns:
        A dict mapping each label to its softmax probability as a float, or an
        empty dict when there is no image or no label to classify against.
    """
    # Nothing to classify yet (e.g. image loaded before any label was entered);
    # skip the model call instead of failing on an empty text batch.
    if image is None or not labels:
        return {}
    inputs = processor(
        text=[f"a photo of {c}" for c in labels],
        images=image,
        return_tensors="pt",
        padding=True,
    )
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
    probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
    return {k: float(v) for k, v in zip(labels, probs[0])}


def rand_image():
    """Return the URL of a randomly chosen dataset photo, resized to ``height``."""
    r = random.randrange(dataset.num_rows)
    # Unsplash allows dynamic requests, including size of image
    return dataset[r]["photo_image_url"] + f"?h={height}"


def set_labels(text):
    """Parse a comma-separated string into a list of labels.

    Surrounding whitespace is stripped and empty entries are dropped, so
    "red, blue," yields ["red", "blue"].
    """
    return [label.strip() for label in text.split(",") if label.strip()]


# Remote captioning Space; the token is read from the "api_key" environment variable.
get_caption = gr.load("ryaalbr/caption", src="spaces", hf_token=environ["api_key"])


def generate_text(image, model_name):
    """Forward ``image`` and ``model_name`` to the remote captioning Space."""
    return get_caption(image, model_name)


# Remote image-search Space; the token is read from the "api_key" environment variable.
get_images = gr.load("ryaalbr/ImageSearch", src="spaces", hf_token=environ["api_key"])


def search_images(text):
    """Forward ``text`` to the remote image-search Space's "images" endpoint."""
    return get_images(text, api_name="images")


# NOTE(review): the layout nesting below was reconstructed from a file whose
# line breaks and indentation were lost — confirm against the intended UI.
with gr.Blocks() as demo:

    with gr.Tab("Zero-Shot Classification"):
        # Hidden component that can store a value and can be used as
        # input/output; here, initial value is an empty list.
        labels = gr.State([])
        instructions = """## Instructions:
1. Enter list of labels separated by commas (or select one of the examples below)
2. Click **Get Random Image** to grab a random image from dataset and analyze it against the labels
3. Click **Re-Classify Image** to re-run classification on current image after changing labels"""
        gr.Markdown(instructions)
        with gr.Row(variant="compact"):
            label_text = gr.Textbox(
                show_label=False,
                placeholder="Enter classification labels",
            ).style(container=False)
        gr.Examples(
            [
                "spring, summer, fall, winter",
                "mountain, city, beach, ocean, desert, forest, valley",
                "red, blue, green, white, black, purple, brown",
                "person, animal, landscape, something else",
                "day, night, dawn, dusk",
            ],
            inputs=label_text,
        )
        with gr.Row():
            with gr.Column(variant="panel"):
                im = gr.Image(interactive=False).style(height=height)
                with gr.Row():
                    get_btn = gr.Button("Get Random Image").style(full_width=False)
                    reclass_btn = gr.Button("Re-Classify Image").style(full_width=False)
            cf = gr.Label()

        # Keep the parsed label list in sync with the textbox.
        label_text.change(fn=set_labels, inputs=label_text, outputs=labels)  # parse list if changed
        # parse list if focus is moved elsewhere; ensures that list is fully parsed before classification
        label_text.blur(fn=set_labels, inputs=label_text, outputs=labels)
        # parse list if user hits enter; ensures that list is fully parsed before classification
        label_text.submit(fn=set_labels, inputs=label_text, outputs=labels)

        # NOTE(review): the two button bindings below were truncated in the
        # original text (only their argument tails survived); the receivers
        # and callbacks are reconstructed — confirm.
        get_btn.click(fn=rand_image, outputs=im)
        im.change(predict, inputs=[im, labels], outputs=cf)
        reclass_btn.click(predict, inputs=[im, labels], outputs=cf)

    with gr.Tab("Image Captioning"):
        with gr.Row():
            with gr.Column(variant="panel"):
                im_cap = gr.Image(interactive=False, type='filepath').style(height=height)
                model_name = gr.Radio(
                    choices=["COCO", "Conceptual captions"],
                    type="value",
                    value="COCO",
                    label="Model",
                ).style(container=True, item_container=False)
                with gr.Row():
                    get_btn_cap = gr.Button("Get Random Image").style(full_width=False)
                    caption_btn = gr.Button("Create Caption").style(full_width=False)
            caption = gr.Textbox(label='Caption')

        # NOTE(review): both bindings were truncated in the original text and
        # are reconstructed from their surviving argument lists — confirm.
        get_btn_cap.click(fn=rand_image, outputs=im_cap)
        caption_btn.click(generate_text, inputs=[im_cap, model_name], outputs=caption)

    with gr.Tab("Image Search"):
        with gr.Column(variant="panel"):
            desc = gr.Textbox(show_label=False, placeholder="Enter description").style(container=False)
            search_btn = gr.Button("Find Images").style(full_width=False)
        gallery = gr.Gallery(show_label=False).style(grid=(2, 2, 3, 5))

        # NOTE(review): binding reconstructed from its surviving argument
        # list (",inputs=desc, outputs=gallery)") — confirm.
        search_btn.click(search_images, inputs=desc, outputs=gallery)

demo.launch()