# QuestApp / app.py
import gradio as gr
from datasets import load_dataset
import random
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from os import environ
import clip
import pickle
import requests
import torch
import os
from huggingface_hub import hf_hub_download
from torch import nn
import torch.nn.functional as nnf
import sys
from typing import Tuple, List, Union, Optional
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
N = type(None)
V = np.array
ARRAY = np.ndarray
ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
VS = Union[Tuple[V, ...], List[V]]
VN = Union[V, N]
VNS = Union[VS, N]
T = torch.Tensor
TS = Union[Tuple[T, ...], List[T]]
TN = Optional[T]
TNS = Union[Tuple[TN, ...], List[TN]]
TSN = Optional[TS]
TA = Union[T, ARRAY]
D = torch.device
CPU = torch.device('cpu')
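# Type aliases above are used by the captioning model below; they appear to be carried over from the CLIP-prefix-captioning reference code.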
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the pre-trained CLIP model and processor
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = clip_model.to(device)
#orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)
# Load the Unsplash dataset
dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train") # all 25K images are in train split
dataset_size = len(dataset)
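# Each record exposes a "photo_image_url" field, which rand_image() below turns into a resized-image request.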
# Load GPT-2 and the modified (CLIP-prefix) weights for captioning
gpt = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
conceptual_weight = hf_hub_download(repo_id="akhaliq/CLIP-prefix-captioning-conceptual-weights", filename="conceptual_weights.pt")
coco_weight = hf_hub_download(repo_id="akhaliq/CLIP-prefix-captioning-COCO-weights", filename="coco_weights.pt")
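# Each checkpoint is loaded as a full ClipCaptionModel state dict (prefix projection plus GPT-2 weights) in get_caption() below.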
height = 256 # height for resizing images
def predict(image, labels):
    with torch.no_grad():
        inputs = clip_processor(text=[f"a photo of {c}" for c in labels], images=image, return_tensors="pt", padding=True).to(device)
        outputs = clip_model(**inputs)
    logits_per_image = outputs.logits_per_image # this is the image-text similarity score
    probs = logits_per_image.softmax(dim=1).cpu().numpy() # we can take the softmax to get the label probabilities
    return {k: float(v) for k, v in zip(labels, probs[0])}
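# Example (hypothetical values): predict(img, ["day", "night"]) would return a dict of
# label probabilities such as {"day": 0.93, "night": 0.07}, summing to 1 across the labels.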
# def predict2(image, labels):
# image = orig_clip_processor(image).unsqueeze(0).to(device)
# text = clip.tokenize(labels).to(device)
# with torch.no_grad():
# image_features = orig_clip_model.encode_image(image)
# text_features = orig_clip_model.encode_text(text)
# logits_per_image, logits_per_text = orig_clip_model(image, text)
# probs = logits_per_image.softmax(dim=-1).cpu().numpy()
# return {k: float(v) for k, v in zip(labels, probs[0])}
def rand_image():
    n = dataset.num_rows
    r = random.randrange(0, n)
    return dataset[r]["photo_image_url"] + f"?h={height}" # Unsplash allows dynamic requests, including size of image
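# The returned value is a plain URL string, e.g. (hypothetical ID)
# "https://images.unsplash.com/photo-12345?h=256", which the gr.Image components below display directly.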
def set_labels(text):
    return text.split(",")
# get_caption = gr.load("ryaalbr/caption", src="spaces", hf_token=environ["api_key"])
# def generate_text(image, model_name):
# return get_caption(image, model_name)
class MLP(nn.Module):
    def forward(self, x: T) -> T:
        return self.model(x)
    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)
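# For example (sizes inferred from ClipCaptionModel below, with prefix_length=10 and GPT-2's 768-dim embeddings):
# MLP((512, 3840, 7680)) builds Linear(512->3840) -> Tanh -> Linear(3840->7680).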
class ClipCaptionModel(nn.Module):
    def get_dummy_token(self, batch_size: int, device: D) -> T:
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
    def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        return out
    def __init__(self, prefix_length: int, prefix_size: int = 512):
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        self.gpt = gpt
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        if prefix_length > 10: # not enough memory
            self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
        else:
            self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))
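# Only clip_project and gpt.generate are exercised at inference time (see get_caption below);
# the forward/labels path appears to be carried over from the CLIP-prefix-captioning training code and is unused here.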
#clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
def get_caption(img, model_name):
    prefix_length = 10
    model = ClipCaptionModel(prefix_length)
    if model_name == "COCO":
        model_path = coco_weight
    else:
        model_path = conceptual_weight
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model = model.eval()
    model = model.to(device)
    input = clip_processor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        prefix = clip_model.get_image_features(**input)
        # image = preprocess(img).unsqueeze(0).to(device)
        # with torch.no_grad():
        #     prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
        prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
    output = model.gpt.generate(inputs_embeds=prefix_embed,
                                num_beams=1,
                                do_sample=False,
                                num_return_sequences=1,
                                no_repeat_ngram_size=1,
                                max_new_tokens=67,
                                pad_token_id=tokenizer.eos_token_id,
                                eos_token_id=tokenizer.encode('.')[0],
                                renormalize_logits=True)
    generated_text_prefix = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text_prefix[:-1] if generated_text_prefix[-1] == "." else generated_text_prefix # remove period at end if present
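# Example (hypothetical output): get_caption(pil_image, "COCO") might return
# something like "A man riding a wave on top of a surfboard".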
# get_images = gr.load("ryaalbr/ImageSearch", src="spaces", hf_token=environ["api_key"])
# def search_images(text):
# return get_images(text, api_name="images")
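# Precomputed CLIP image embeddings for the Unsplash photos: id2url maps a photo ID to its URL,
# img_names holds the file name for each embedding row, and img_emb is the (num_photos x 512) embedding matrix.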
emb_filename = 'unsplash-25k-photos-embeddings-indexes.pkl'
with open(emb_filename, 'rb') as emb:
    id2url, img_names, img_emb = pickle.load(emb)
def search(search_query):
    with torch.no_grad():
        # Encode the description using CLIP (HF CLIP)
        inputs = clip_processor(text=search_query, images=None, return_tensors="pt", padding=True).to(device)
        text_encoded = clip_model.get_text_features(**inputs)
        # # Encode and normalize the description using CLIP (original CLIP)
        # text_encoded = orig_clip_model.encode_text(clip.tokenize(search_query))
        # text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
    # Retrieve the description vector
    text_features = text_encoded.cpu().numpy()
    # Compute the similarity between the description and each photo using cosine similarity
    similarities = (text_features @ img_emb.T).squeeze(0)
    # Sort the photos by their similarity score and keep the 15 best candidates
    best_photos = similarities.argsort()[::-1]
    best_photos = best_photos[:15]
    #best_photos = sorted(zip(similarities, range(img_emb.shape[0])), key=lambda x: x[0], reverse=True)
    best_photo_ids = img_names[best_photos]
    imgs = []
    # Iterate over the candidates until 5 images with known URLs are collected
    for id in best_photo_ids:
        id, _ = id.split('.')
        url = id2url.get(id, "")
        if url == "":
            continue
        img = url + "?h=512"
        # r = requests.get(url + "?w=512", stream=True)
        # img = Image.open(r.raw)
        #credits = f'Photo by <a href="https://unsplash.com/@{photo["photographer_username"]}?utm_source=NaturalLanguageImageSearch&utm_medium=referral">{photo["photographer_first_name"]} {photo["photographer_last_name"]}</a> on <a href="https://unsplash.com/?utm_source=NaturalLanguageImageSearch&utm_medium=referral">Unsplash</a>'
        imgs.append(img)
        #display(HTML(f'Photo by <a href="https://unsplash.com/@{photo["photographer_username"]}?utm_source=NaturalLanguageImageSearch&utm_medium=referral">{photo["photographer_first_name"]} {photo["photographer_last_name"]}</a> on <a href="https://unsplash.com/?utm_source=NaturalLanguageImageSearch&utm_medium=referral">Unsplash</a>'))
        if len(imgs) == 5:
            break
    return imgs
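# Example (hypothetical query): search("a waterfall in Iceland") would return up to 5 Unsplash
# image URLs (each with an "?h=512" size hint), ordered from best to worst match.
# Gradio UI below: three tabs (Classification, Captioning, Search) wiring the functions above to the interface.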
with gr.Blocks() as demo:
    with gr.Tab("Classification"):
        labels = gr.State([]) # creates hidden component that can store a value and can be used as input/output; here, initial value is an empty list
        instructions = """## Instructions:
1. Enter a list of labels separated by commas (or select one of the examples below)
2. Click **Get Random Image** to grab a random image from the dataset
3. Click **Classify Image** to analyze the current image against the labels (including after changing labels)
"""
        gr.Markdown(instructions)
        with gr.Row(variant="compact"):
            label_text = gr.Textbox(show_label=False, placeholder="Enter classification labels").style(container=False)
            #submit_btn = gr.Button("Submit").style(full_width=False)
            gr.Examples(["spring, summer, fall, winter",
                         "mountain, city, beach, ocean, desert, forest, valley",
                         "red, blue, green, white, black, purple, brown",
                         "person, animal, landscape, something else",
                         "day, night, dawn, dusk"], inputs=label_text)
        with gr.Row():
            with gr.Column(variant="panel"):
                im = gr.Image(interactive=False).style(height=height)
                with gr.Row():
                    get_btn = gr.Button("Get Random Image").style(full_width=False)
                    class_btn = gr.Button("Classify Image").style(full_width=False)
            cf = gr.Label()
        #submit_btn.click(fn=set_labels, inputs=label_text)
        label_text.change(fn=set_labels, inputs=label_text, outputs=labels) # parse list if changed
        label_text.blur(fn=set_labels, inputs=label_text, outputs=labels) # parse list if focus is moved elsewhere; ensures that list is fully parsed before classification
        label_text.submit(fn=set_labels, inputs=label_text, outputs=labels) # parse list if user hits enter; ensures that list is fully parsed before classification
        get_btn.click(fn=rand_image, outputs=im)
        #im.change(predict, inputs=[im, labels], outputs=cf)
        class_btn.click(predict, inputs=[im, labels], outputs=cf)
        gr.HTML(f"Dataset: <a href='https://github.com/unsplash/datasets' target='_blank'>Unsplash Lite</a>; Number of Images: {dataset_size}")
with gr.Tab("Captioning"):
instructions = """## Instructions:
1. Click **Get Random Image** to grab a random image from dataset
1. Click **Create Caption** to generate a caption for the image (usually takes 5-10s but could be over 60s)
1. Different models can be selected:
* **COCO** generally produces more straight-forward captions, but it is a smaller dataset and therefore struggles to recognize certain objects
* **Conceptual Captions** is a much larger dataset but sometimes produces results that resemble social media posts rather than captions
"""
gr.Markdown(instructions)
with gr.Row():
with gr.Column(variant="panel"):
im_cap = gr.Image(interactive=False).style(height=height)
model_name = gr.Radio(choices=["COCO","Conceptual Captions"], type="value", value="COCO", label="Model").style(container=True, item_container = False)
with gr.Row():
get_btn_cap = gr.Button("Get Random Image").style(full_width=False)
caption_btn = gr.Button("Create Caption").style(full_width=False)
caption = gr.Textbox(label='Caption', elem_classes="caption-text")
get_btn_cap.click(fn=rand_image, outputs=im_cap)
#im_cap.change(generate_text, inputs=im_cap, outputs=caption)
caption_btn.click(get_caption, inputs=[im_cap, model_name], outputs=caption)
gr.HTML(f"Dataset: <a href='https://github.com/unsplash/datasets' target='_blank'>Unsplash Lite</a>; Number of Images: {dataset_size}")
with gr.Tab("Search"):
instructions = """## Instructions:
1. Enter a search query (or select one of the examples below)
2. Click **Find Images** to find images that match the query (top 5 are shown in order from left to right)
3. Keep in mind that the dataset contains mostly nature-focused images"""
gr.Markdown(instructions)
with gr.Column(variant="panel"):
desc = gr.Textbox(show_label=False, placeholder="Enter description").style(container=False)
gr.Examples(["someone holding flowers",
"someone holding pink flowers",
"red fruit in a person's hands",
"an aerial view of forest",
"a waterfall in Iceland with a rainbow"
], inputs=desc)
search_btn = gr.Button("Find Images").style(full_width=False)
gallery = gr.Gallery(show_label=False).style(grid=(2,2,3,5))
search_btn.click(search,inputs=desc, outputs=gallery, postprocess=False)
gr.HTML(f"Dataset: <a href='https://github.com/unsplash/datasets' target='_blank'>Unsplash Lite</a>; Number of Images: {dataset_size}")
demo.queue(concurrency_count=3)
demo.launch()