import gradio as gr
from datasets import load_dataset
import random
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from os import environ
import clip
import pickle
import requests
import torch
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face `transformers` variant of the CLIP model and processor (currently disabled):
# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Load the ViT-B/32 checkpoint through the `clip` package onto `device`.
# The second value returned by clip.load is kept as `orig_clip_processor`.
orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)
# Load the Unsplash dataset
dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train") # all 25K images are in train split
# Requested image height; appended to photo URLs as `?h=` and used to size the image widgets.
height = 256 # height for resizing images
def _load_hf_clip():
    """Load the Hugging Face CLIP model and processor once and reuse them."""
    cached = getattr(_load_hf_clip, "_cache", None)
    if cached is None:
        checkpoint = "openai/clip-vit-base-patch32"
        cached = (
            CLIPModel.from_pretrained(checkpoint),
            CLIPProcessor.from_pretrained(checkpoint),
        )
        # Stored on the function object so later calls skip the expensive load.
        _load_hf_clip._cache = cached
    return cached


def predict(image, labels):
    """Classify *image* against *labels* with the Hugging Face CLIP model.

    Each label is wrapped in the prompt ``"a photo of {label}"`` before being
    scored against the image.

    Args:
        image: Image in any form accepted by ``CLIPProcessor``.
        labels: Sequence of label strings.

    Returns:
        Dict mapping each label to its softmax probability as a float.
        Empty when there is no image or no labels.
    """
    # Nothing to score; also avoids loading the model needlessly.
    if image is None or not labels:
        return {}

    # The module-level model/processor are commented out, so load them lazily
    # here instead of relying on undefined globals.
    model, processor = _load_hf_clip()

    prompts = [f"a photo of {c}" for c in labels]
    inputs = processor(text=prompts, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image  # image-text similarity scores
    probs = logits_per_image.softmax(dim=1)  # normalise over the labels
    return {k: float(v) for k, v in zip(labels, probs[0])}
def predict2(image, labels):
    """Classify *image* against *labels* with the OpenAI ``clip`` model.

    Args:
        image: Image to classify; passed straight to ``orig_clip_processor``.
        labels: Sequence of label strings, tokenized as-is (no prompt template).

    Returns:
        Dict mapping each label to its softmax probability as a float.
        Empty when there is no image or no labels.
    """
    # The label state starts out as an empty list, so guard before tokenizing.
    if image is None or not labels:
        return {}

    # NOTE(review): the CLIP preprocess pipeline is normally fed a PIL image,
    # while gr.Image may deliver a numpy array - confirm the component's `type`
    # or add a conversion upstream.
    image_input = orig_clip_processor(image).unsqueeze(0).to(device)
    text_input = clip.tokenize(labels).to(device)

    with torch.no_grad():
        # One forward pass yields the similarity logits; encoding the image and
        # text separately beforehand was redundant work.
        logits_per_image, _ = orig_clip_model(image_input, text_input)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    return {k: float(v) for k, v in zip(labels, probs[0])}
def rand_image():
    """Return the URL of a randomly chosen dataset photo, sized to ``height``."""
    total_rows = dataset.num_rows
    row_index = random.randrange(0, total_rows)
    photo_url = dataset[row_index]["photo_image_url"]
    # Unsplash serves dynamically resized images via the `h` query parameter.
    return photo_url + f"?h={height}"
def set_labels(text):
    """Parse a comma-separated string into a list of label strings.

    Whitespace around each label is stripped and empty entries are dropped,
    so ``"a, b,"`` yields ``["a", "b"]``.

    Args:
        text: Raw textbox contents; may be empty or ``None``.

    Returns:
        List of non-empty, stripped labels (possibly empty).
    """
    if not text:
        return []
    stripped = (part.strip() for part in text.split(","))
    return [label for label in stripped if label]
# Load the "ryaalbr/caption" Space as a callable. The Hugging Face token is read
# from the `api_key` environment variable; a missing variable raises KeyError here.
get_caption = gr.load("ryaalbr/caption", src="spaces", hf_token=environ["api_key"])
def generate_text(image, model_name):
    """Return the caption the remote captioning Space produces for *image*.

    Both arguments are forwarded unchanged to ``get_caption``.
    """
    caption_text = get_caption(image, model_name)
    return caption_text
# get_images = gr.load("ryaalbr/ImageSearch", src="spaces", hf_token=environ["api_key"])
# def search_images(text):
# return get_images(text, api_name="images")
# Pickle file holding a 3-tuple of lookup data used by search().
emb_filename = 'unsplash-25k-photos-embeddings-indexes.pkl'
# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
with open(emb_filename, 'rb') as emb:
    # id2url: mapping queried with .get(photo_id) to obtain a photo URL
    # img_names: indexed with the ranked similarity indices to get file names
    # img_emb: embedding matrix multiplied against the query embedding
    id2url, img_names, img_emb = pickle.load(emb)
def search(search_query, max_results=5, num_candidates=15):
    """Return URLs of the photos that best match a text description.

    The query is embedded with CLIP, compared against the precomputed image
    embeddings, and the best-ranked photos with a known URL are returned.

    Args:
        search_query: Free-text description of the wanted image.
        max_results: Maximum number of URLs to return.
        num_candidates: How many top-ranked photos to inspect; larger than
            ``max_results`` because some photos have no URL in ``id2url``.

    Returns:
        List of image URLs (each with ``?h=512`` appended), best match first.
        Empty when the query is blank.
    """
    # A blank query has no meaningful embedding; skip the model call entirely.
    if not search_query or (isinstance(search_query, str) and not search_query.strip()):
        return []

    with torch.no_grad():
        # Tokens must live on the same device as the model.
        tokens = clip.tokenize(search_query).to(device)
        # Encode and normalize the description using CLIP.
        text_encoded = orig_clip_model.encode_text(tokens)
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    text_features = text_encoded.cpu().numpy()

    # Similarity between the description and every photo embedding.
    similarities = (text_features @ img_emb.T).squeeze(0)

    # Indices of the highest-scoring photos, best first.
    best_photos = similarities.argsort()[::-1][:num_candidates]

    imgs = []
    for file_name in img_names[best_photos]:
        if len(imgs) >= max_results:
            break
        # Drop only the final extension so names containing dots still work.
        photo_id = file_name.rsplit(".", 1)[0]
        url = id2url.get(photo_id, "")
        if not url:
            continue  # no known URL for this photo
        imgs.append(url + "?h=512")
    return imgs
# Build the three-tab Gradio UI and wire each widget to its handler.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from statement order - confirm it matches the intended layout.
with gr.Blocks() as demo:
    # --- Tab 1: classify a random image against user-supplied labels ---
    with gr.Tab("Zero-Shot Classification"):
        labels = gr.State([]) # creates hidden component that can store a value and can be used as input/output; here, initial value is an empty list
        instructions = """## Instructions:
1. Enter list of labels separated by commas (or select one of the examples below)
2. Click **Get Random Image** to grab a random image from dataset and analyze it against the labels
3. Click **Re-Classify Image** to re-run classification on current image after changing labels"""
        gr.Markdown(instructions)
        with gr.Row(variant="compact"):
            label_text = gr.Textbox(show_label=False, placeholder="Enter classification labels").style(container=False)
            #submit_btn = gr.Button("Submit").style(full_width=False)
        # Example label sets that fill the label textbox.
        gr.Examples(["spring, summer, fall, winter",
                     "mountain, city, beach, ocean, desert, forest, valley",
                     "red, blue, green, white, black, purple, brown",
                     "person, animal, landscape, something else",
                     "day, night, dawn, dusk"], inputs=label_text)
        with gr.Row():
            with gr.Column(variant="panel"):
                im = gr.Image(interactive=False).style(height=height)
                with gr.Row():
                    get_btn = gr.Button("Get Random Image").style(full_width=False)
                    reclass_btn = gr.Button("Re-Classify Image").style(full_width=False)
            # Shows the classification result returned by predict2.
            cf = gr.Label()
        #submit_btn.click(fn=set_labels, inputs=label_text)
        label_text.change(fn=set_labels, inputs=label_text, outputs=labels) # parse list if changed
        label_text.blur(fn=set_labels, inputs=label_text, outputs=labels) # parse list if focus is moved elsewhere; ensures that list is fully parsed before classification
        label_text.submit(fn=set_labels, inputs=label_text, outputs=labels) # parse list if user hits enter; ensures that list is fully parsed before classification
        # Fetching an image updates `im`, whose change event then runs the classifier.
        get_btn.click(fn=rand_image, outputs=im)
        im.change(predict2, inputs=[im, labels], outputs=cf)
        # Re-run classification on the current image with the current labels.
        reclass_btn.click(predict2, inputs=[im, labels], outputs=cf)
    # --- Tab 2: caption a random image with the selected captioning model ---
    with gr.Tab("Image Captioning"):
        with gr.Row():
            with gr.Column(variant="panel"):
                # type='filepath': the handler receives a file path rather than image data.
                im_cap = gr.Image(interactive=False, type='filepath').style(height=height)
                model_name = gr.Radio(choices=["COCO","Conceptual captions"], type="value", value="COCO", label="Model").style(container=True, item_container = False)
                with gr.Row():
                    get_btn_cap = gr.Button("Get Random Image").style(full_width=False)
                    caption_btn = gr.Button("Create Caption").style(full_width=False)
            caption = gr.Textbox(label='Caption')
        get_btn_cap.click(fn=rand_image, outputs=im_cap)
        #im_cap.change(generate_text, inputs=im_cap, outputs=caption)
        # Captioning runs only on button click; the automatic on-change hook above is disabled.
        caption_btn.click(generate_text, inputs=[im_cap, model_name], outputs=caption)
    # --- Tab 3: text-to-image search over the precomputed embeddings ---
    with gr.Tab("Image Search"):
        with gr.Column(variant="panel"):
            desc = gr.Textbox(show_label=False, placeholder="Enter description").style(container=False)
            search_btn = gr.Button("Find Images").style(full_width=False)
        gallery = gr.Gallery(show_label=False).style(grid=(2,2,3,5))
        # search() returns a list of image URLs, which the gallery displays.
        search_btn.click(search,inputs=desc, outputs=gallery)
# Start the Gradio app.
demo.launch()