import gradio as gr
from datasets import load_dataset
import random
import numpy as np
from transformers import CLIPProcessor, CLIPModel
import os

# Load the pre-trained model and processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load the Unsplash dataset
dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")  # all 25K images are in train split

height = 250   # height for resizing images

def predict(image, labels):
    inputs = processor(text=[f"a photo of {c}" for c in labels], images=image, return_tensors="pt", padding=True)
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image # this is the image-text similarity score
    probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
    return {k: float(v) for k, v in zip(labels, probs[0])}

def rand_image():
    n = dataset.num_rows
    r = random.randrange(0,n)
    return dataset[r]["photo_image_url"] # + f"?h={height}"  # Unsplash allows dynamic requests, including size of image

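# Split the comma-separated label text into a list for the classifier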
def set_labels(text):
    return text.split(",")
    
# Load the captioning Space as a callable, authenticating with an API key stored in the environment
get_caption = gr.load("ryaalbr/caption", src="spaces", hf_token=os.environ.get("api_key"))
def generate_text(image):
    return get_caption(image)
    
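# Build the two-tab UI: zero-shot classification and image captioning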
with gr.Blocks() as demo:

    with gr.Tab("Zero-Shot Classification"):
        labels = gr.State([])   # creates hidden component that can store a value and can be used as input/output; here, initial value is an empty list
        instructions = """## Instructions:
        1. Enter list of labels separated by commas (or select one of the examples below)
        2. Click **Get Random Image** to grab a random image from dataset and analyze it against the labels
        3. Click **Re-Classify Image** to re-run classification on current image after changing labels"""
        gr.Markdown(instructions)
        with gr.Row(variant="compact"):
            label_text = gr.Textbox(show_label=False, placeholder="Enter classification labels").style(container=False)
            #submit_btn = gr.Button("Submit").style(full_width=False)
        gr.Examples(["spring, summer, fall, winter",
                     "mountain, city, beach, ocean, desert, forest, valley", 
                     "red, blue, green, white, black, purple, brown", 
                     "person, animal, landscape, something else",
                     "day, night, dawn, dusk"], inputs=label_text)
        with gr.Row():
            with gr.Column(variant="panel"):
                im = gr.Image(interactive=False).style(height=height)
                with gr.Row():
                    get_btn = gr.Button("Get Random Image").style(full_width=False)
                    reclass_btn = gr.Button("Re-Classify Image").style(full_width=False)
            cf = gr.Label()
        #submit_btn.click(fn=set_labels, inputs=label_text)
        label_text.change(fn=set_labels, inputs=label_text, outputs=labels)  # parse list if changed
        label_text.blur(fn=set_labels, inputs=label_text, outputs=labels)  # parse list if focus is moved elsewhere; ensures that list is fully parsed before classification
        label_text.submit(fn=set_labels, inputs=label_text, outputs=labels)  # parse list if user hits enter; ensures that list is fully parsed before classification
        get_btn.click(fn=rand_image, outputs=im)
        im.change(predict, inputs=[im, labels], outputs=cf)
        reclass_btn.click(predict, inputs=[im, labels], outputs=cf)

    with gr.Tab("Image Captioning"):
        with gr.Row():
            with gr.Column(variant="panel"):
                im_cap = gr.Image(interactive=False).style(height=height)
                with gr.Row():
                    get_btn_cap = gr.Button("Get Random Image").style(full_width=False)
                    caption_btn = gr.Button("Create Caption").style(full_width=False)
            caption = gr.Text()
        get_btn_cap.click(fn=rand_image, outputs=im_cap)
        #im_cap.change(generate_text, inputs=im_cap, outputs=caption)
        caption_btn.click(generate_text, inputs=im_cap, outputs=caption)        

demo.queue()
demo.launch()