File size: 7,650 Bytes
ee0cae7
0d3a066
ee0cae7
 
 
d3a50f6
dbf8f9d
 
 
 
 
971e64d
 
 
 
 
 
 
ee0cae7
 
 
f7512c4
ee0cae7
741aba6
ee0cae7
f7512c4
ee0cae7
 
 
 
 
 
971e64d
 
 
 
 
 
 
 
 
 
 
ee0cae7
 
 
741aba6
ee0cae7
 
46228a6
ee0cae7
bee7306
31a5db6
 
bf143f5
dbf8f9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee0cae7
0d38b2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
971e64d
 
0d38b2c
 
 
 
31a5db6
a801546
0d38b2c
 
 
31a5db6
0d38b2c
 
bf143f5
 
 
 
 
 
6218a98
dbf8f9d
ee0cae7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import gradio as gr
from datasets import load_dataset
import random
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from os import environ
import clip
import pickle
import requests
import torch

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face `transformers` variant of the CLIP model; currently disabled.
# # Load the pre-trained model and processor
# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# CLIP ViT-B/32 loaded through the `clip` package. The second returned object is
# called on an image elsewhere in this file to turn it into a tensor for the model.
orig_clip_model, orig_clip_processor = clip.load("ViT-B/32", device=device, jit=False)


# Load the Unsplash dataset
dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")  # all 25K images are in train split

# Shared by the random-image URL (`?h=` query) and the image components of the UI.
height = 256   # height for resizing images

# Checkpoint used by the Hugging Face implementation of CLIP.
_HF_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
# Filled on first use by _load_hf_clip(); holds the keys "model" and "processor".
_hf_clip = {}


def _load_hf_clip():
    """Load the Hugging Face CLIP model and processor once, then reuse them."""
    if not _hf_clip:
        # Load both before caching so a failed download never leaves a half-filled cache.
        hf_model = CLIPModel.from_pretrained(_HF_CLIP_CHECKPOINT)
        hf_processor = CLIPProcessor.from_pretrained(_HF_CLIP_CHECKPOINT)
        _hf_clip.update(model=hf_model, processor=hf_processor)
    return _hf_clip["model"], _hf_clip["processor"]


def predict(image, labels):
    """Zero-shot classify `image` against `labels` with the Hugging Face CLIP model.

    Each label is wrapped in the prompt "a photo of <label>" before scoring.

    Args:
        image: Image in a form accepted by `CLIPProcessor` (passed as `images=`).
        labels: Sequence of label strings.

    Returns:
        Dict mapping each label to its softmax probability, or an empty dict
        when there is no image or no labels to score.
    """
    if image is None or not labels:
        return {}

    # The model and processor are loaded lazily; the original referenced
    # module-level `model` / `processor` names that are never defined.
    model, processor = _load_hf_clip()
    inputs = processor(text=[f"a photo of {c}" for c in labels], images=image, return_tensors="pt", padding=True)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image  # image-text similarity scores
    probs = logits_per_image.softmax(dim=1)  # softmax over the labels
    return {k: float(v) for k, v in zip(labels, probs[0])}


def predict2(image, labels):
    """Zero-shot classify `image` against `labels` with the OpenAI `clip` model.

    Args:
        image: Image to classify; it is passed to `orig_clip_processor`.
        labels: List of label strings, tokenized as-is (no prompt template).

    Returns:
        Dict mapping each label to its softmax probability, or an empty dict
        when there is no image or no labels to score.
    """
    if image is None or not labels:
        return {}

    # NOTE(review): confirm that the value delivered by the `gr.Image` input is a
    # type `orig_clip_processor` accepts (CLIP preprocessing is typically fed PIL images).
    image_input = orig_clip_processor(image).unsqueeze(0).to(device)  # add batch dimension
    text_tokens = clip.tokenize(labels).to(device)
    with torch.no_grad():
        # The forward pass encodes both inputs itself, so no separate
        # encode_image / encode_text calls are needed.
        logits_per_image, _ = orig_clip_model(image_input, text_tokens)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    return {label: float(prob) for label, prob in zip(labels, probs[0])}

def rand_image():
    """Return the URL of a randomly chosen dataset photo, with the display height requested."""
    row_index = random.randrange(0, dataset.num_rows)
    photo_url = dataset[row_index]["photo_image_url"]
    # Unsplash allows dynamic requests, so the image size is asked for via the query string.
    return photo_url + f"?h={height}"

def set_labels(text):
    """Parse a comma-separated string into a list of labels.

    Whitespace around each label is removed and empty entries (from empty
    input, doubled commas or a trailing comma) are dropped, so the result
    never contains blank labels.

    Args:
        text: Comma-separated labels; may be empty or None.

    Returns:
        List of non-empty, stripped label strings (possibly empty).
    """
    if not text:
        return []
    stripped = (part.strip() for part in text.split(","))
    return [label for label in stripped if label]
    
# Client for the remote "ryaalbr/caption" Space, used as a callable by generate_text.
# The token is read from the `api_key` environment variable; a missing variable
# raises KeyError here, at import time.
get_caption = gr.load("ryaalbr/caption", src="spaces", hf_token=environ["api_key"])
def generate_text(image, model_name):
    """Return the caption produced by `get_caption` for `image` using `model_name`."""
    caption_text = get_caption(image, model_name)
    return caption_text

# get_images = gr.load("ryaalbr/ImageSearch", src="spaces", hf_token=environ["api_key"])
# def search_images(text):
#     return get_images(text, api_name="images")

# Precomputed image-search index, unpickled at import time into three objects used by
# `search`: `id2url` (queried with `.get(photo_id, "")`), `img_names` (indexed with the
# ranked positions) and `img_emb` (matrix-multiplied with the query embedding).
# NOTE(review): pickle.load can execute arbitrary code — only ship a trusted file here.
emb_filename = 'unsplash-25k-photos-embeddings-indexes.pkl'
with open(emb_filename, 'rb') as emb:
        id2url, img_names, img_emb = pickle.load(emb)


def search(search_query, max_results=5, num_candidates=15):
    """Find photos whose precomputed embeddings best match a text description.

    The query is encoded with CLIP, L2-normalised and scored against every
    row of `img_emb` with a dot product (cosine similarity, assuming the rows
    of `img_emb` are unit-normalised — TODO confirm). The best candidates are
    then mapped to URLs through `img_names` and `id2url`.

    Args:
        search_query: Free-text description of the desired photo.
        max_results: Maximum number of image URLs to return.
        num_candidates: How many top-ranked photos to inspect; larger than
            `max_results` because photos without a known URL are skipped.

    Returns:
        List of image URLs (each with "?h=512" appended), best match first.
        Empty when the query is empty or blank.
    """
    if not search_query or not search_query.strip():
        return []

    with torch.no_grad():
        # Tokens must live on the same device as the model (matters on CUDA).
        tokens = clip.tokenize(search_query).to(device)
        # Encode and normalize the description using CLIP
        text_encoded = orig_clip_model.encode_text(tokens)
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    text_features = text_encoded.cpu().numpy()

    # Similarity between the description and each photo
    similarities = (text_features @ img_emb.T).squeeze(0)

    # Highest-scoring photos first, limited to the candidate pool
    best_photos = similarities.argsort()[::-1][:num_candidates]

    imgs = []
    for file_name in img_names[best_photos]:
        if len(imgs) >= max_results:
            break

        # Drop the file extension; tolerate names with no dot or with several dots.
        photo_id = file_name.rsplit('.', 1)[0]
        url = id2url.get(photo_id, "")
        if not url:
            continue  # no URL known for this photo

        imgs.append(url + "?h=512")

    return imgs





# Gradio UI with three tabs: zero-shot classification, image captioning and image search.
with gr.Blocks() as demo:

    # Tab 1: classify a random dataset image against user-supplied labels.
    with gr.Tab("Zero-Shot Classification"):
        labels = gr.State([])   # creates hidden component that can store a value and can be used as input/output; here, initial value is an empty list
        instructions = """## Instructions:
        1. Enter list of labels separated by commas (or select one of the examples below)
        2. Click **Get Random Image** to grab a random image from dataset and analyze it against the labels
        3. Click **Re-Classify Image** to re-run classification on current image after changing labels"""
        gr.Markdown(instructions)
        with gr.Row(variant="compact"):
            label_text = gr.Textbox(show_label=False, placeholder="Enter classification labels").style(container=False)
            #submit_btn = gr.Button("Submit").style(full_width=False)
        # Clickable example label lists that fill the label textbox.
        gr.Examples(["spring, summer, fall, winter",
                     "mountain, city, beach, ocean, desert, forest, valley", 
                     "red, blue, green, white, black, purple, brown", 
                     "person, animal, landscape, something else",
                     "day, night, dawn, dusk"], inputs=label_text)
        with gr.Row():
            with gr.Column(variant="panel"):
                im = gr.Image(interactive=False).style(height=height)
                with gr.Row():
                    get_btn = gr.Button("Get Random Image").style(full_width=False)
                    reclass_btn = gr.Button("Re-Classify Image").style(full_width=False)
            cf = gr.Label()
        #submit_btn.click(fn=set_labels, inputs=label_text)
        # Keep the hidden `labels` state in sync with the textbox on every kind of edit event.
        label_text.change(fn=set_labels, inputs=label_text, outputs=labels)  # parse list if changed
        label_text.blur(fn=set_labels, inputs=label_text, outputs=labels)  # parse list if focus is moved elsewhere; ensures that list is fully parsed before classification
        label_text.submit(fn=set_labels, inputs=label_text, outputs=labels)  # parse list if user hits enter; ensures that list is fully parsed before classification
        get_btn.click(fn=rand_image, outputs=im)
        # Classify whenever the displayed image changes, and again on explicit request.
        im.change(predict2, inputs=[im, labels], outputs=cf)
        reclass_btn.click(predict2, inputs=[im, labels], outputs=cf)

    # Tab 2: caption a random dataset image with the selected captioning model.
    with gr.Tab("Image Captioning"):
        with gr.Row():
            with gr.Column(variant="panel"):
                # type='filepath': the image value is handed to callbacks as a file path.
                im_cap = gr.Image(interactive=False, type='filepath').style(height=height)
                model_name = gr.Radio(choices=["COCO","Conceptual captions"], type="value", value="COCO", label="Model").style(container=True, item_container = False)
                with gr.Row():
                    get_btn_cap = gr.Button("Get Random Image").style(full_width=False)
                    caption_btn = gr.Button("Create Caption").style(full_width=False)
            caption = gr.Textbox(label='Caption')
        get_btn_cap.click(fn=rand_image, outputs=im_cap)
        #im_cap.change(generate_text, inputs=im_cap, outputs=caption)
        # Captioning runs only on button click, not automatically on image change.
        caption_btn.click(generate_text, inputs=[im_cap, model_name], outputs=caption)

    # Tab 3: text-to-image search; results from `search` are shown in a gallery.
    with gr.Tab("Image Search"):
        with gr.Column(variant="panel"):
            desc = gr.Textbox(show_label=False, placeholder="Enter description").style(container=False)
            search_btn = gr.Button("Find Images").style(full_width=False)
        gallery = gr.Gallery(show_label=False).style(grid=(2,2,3,5))
        search_btn.click(search,inputs=desc, outputs=gallery)

# Start the Gradio app.
demo.launch()