# QuestApp / app.py
# Module setup: imports, pre-trained CLIP model, and the Unsplash image dataset.
import os  # fix: was `import os.environ`, which raises ModuleNotFoundError (os is a module, not a package)
import random

import gradio as gr
import numpy as np
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel

# Load the pre-trained CLIP model and processor used for zero-shot classification
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load the Unsplash dataset (all 25K images are in the "train" split)
dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")

height = 250  # height (px) used when displaying/resizing images
def predict(image, labels):
    """Zero-shot classify `image` against `labels` with CLIP.

    Returns a {label: probability} dict suitable for a gr.Label output.
    """
    prompts = [f"a photo of {c}" for c in labels]
    inputs = processor(text=prompts, images=image, return_tensors="pt", padding=True)
    outputs = model(**inputs)
    similarity = outputs.logits_per_image  # image-text similarity scores
    probabilities = similarity.softmax(dim=1)[0]  # softmax over the candidate labels
    return dict(zip(labels, map(float, probabilities)))
def rand_image():
    """Return the URL of a uniformly random photo from the Unsplash dataset."""
    # Unsplash serves dynamic requests, so a size could be appended, e.g. + f"?h={height}"
    idx = random.randrange(dataset.num_rows)
    return dataset[idx]["photo_image_url"]
def set_labels(text):
    """Parse a comma-separated string into a list of classification labels.

    Whitespace around each label is stripped (the UI examples use ", "
    separators, which previously produced labels with leading spaces and
    prompts like "a photo of  summer"), and empty entries from stray or
    trailing commas are dropped.
    """
    return [label.strip() for label in text.split(",") if label.strip()]
# Load the external captioning Space; the HF token comes from the "api_key" secret.
# Fix: os.environ is a mapping, not a callable -- os.environ("api_key") raised TypeError.
get_caption = gr.load("ryaalbr/caption", src="spaces", hf_token=os.environ.get("api_key"))
def generate_text(image):
    """Generate a caption for `image` by delegating to the remote captioning Space."""
    caption_text = get_caption(image)
    return caption_text
# Build the two-tab Gradio UI: zero-shot classification and image captioning.
with gr.Blocks() as demo:
    with gr.Tab("Zero-Shot Classification"):
        labels = gr.State([]) # creates hidden component that can store a value and can be used as input/output; here, initial value is an empty list
        instructions = """## Instructions:
1. Enter list of labels separated by commas (or select one of the examples below)
2. Click **Get Random Image** to grab a random image from dataset and analyze it against the labels
3. Click **Re-Classify Image** to re-run classification on current image after changing labels"""
        gr.Markdown(instructions)
        with gr.Row(variant="compact"):
            label_text = gr.Textbox(show_label=False, placeholder="Enter classification labels").style(container=False)
            #submit_btn = gr.Button("Submit").style(full_width=False)
            gr.Examples(["spring, summer, fall, winter",
                         "mountain, city, beach, ocean, desert, forest, valley",
                         "red, blue, green, white, black, purple, brown",
                         "person, animal, landscape, something else",
                         "day, night, dawn, dusk"], inputs=label_text)
        with gr.Row():
            with gr.Column(variant="panel"):
                im = gr.Image(interactive=False).style(height=height)
                with gr.Row():
                    get_btn = gr.Button("Get Random Image").style(full_width=False)
                    reclass_btn = gr.Button("Re-Classify Image").style(full_width=False)
            cf = gr.Label()  # classification output (label -> probability bars)
        #submit_btn.click(fn=set_labels, inputs=label_text)
        # Keep the hidden `labels` state in sync with the textbox on every
        # plausible commit point (change / blur / submit) so classification
        # never runs against a stale label list.
        label_text.change(fn=set_labels, inputs=label_text, outputs=labels) # parse list if changed
        label_text.blur(fn=set_labels, inputs=label_text, outputs=labels) # parse list if focus is moved elsewhere; ensures that list is fully parsed before classification
        label_text.submit(fn=set_labels, inputs=label_text, outputs=labels) # parse list if user hits enter; ensures that list is fully parsed before classification
        get_btn.click(fn=rand_image, outputs=im)
        im.change(predict, inputs=[im, labels], outputs=cf)  # new image -> classify immediately
        reclass_btn.click(predict, inputs=[im, labels], outputs=cf)
    with gr.Tab("Image Captioning"):
        with gr.Row():
            with gr.Column(variant="panel"):
                im_cap = gr.Image(interactive=False).style(height=height)
                with gr.Row():
                    get_btn_cap = gr.Button("Get Random Image").style(full_width=False)
                    caption_btn = gr.Button("Create Caption").style(full_width=False)
            caption = gr.Text()  # caption output box
        get_btn_cap.click(fn=rand_image, outputs=im_cap)
        #im_cap.change(generate_text, inputs=im_cap, outputs=caption)
        caption_btn.click(generate_text, inputs=im_cap, outputs=caption)

demo.queue()  # enable request queuing before launching the app
demo.launch()