Update app.py
Change labels from global variable to Gradio state component
Add comments
Move image height to variable
Update examples
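The first change listed above swaps a module-level `labels` global for a `gr.State` component. The snippet below is a minimal sketch of that pattern, separate from this Space's code (the component and function names are illustrative): the parsing function returns the new value, Gradio writes it into the `gr.State` given as `outputs`, and a later event can read it back by listing the same state component as an input.

```python
import gradio as gr

def parse_labels(text):
    # Return the parsed list; Gradio stores it in the gr.State component
    # wired up as `outputs` below, so no global variable is needed.
    return [t.strip() for t in text.split(",") if t.strip()]

def count_labels(labels):
    # A later event reads the stored list by passing the State as an input.
    return f"{len(labels)} labels stored"

with gr.Blocks() as demo:
    labels = gr.State([])  # per-session storage, starts as an empty list
    box = gr.Textbox(label="Labels (comma-separated)")
    status = gr.Textbox(label="Status")
    btn = gr.Button("Count stored labels")

    box.change(parse_labels, inputs=box, outputs=labels)    # update the stored list
    btn.click(count_labels, inputs=labels, outputs=status)  # read the stored list

demo.launch()
```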
app.py
CHANGED
@@ -5,33 +5,33 @@ import numpy as np
 from PIL import Image
 from transformers import CLIPProcessor, CLIPModel
 
+# Load the pre-trained model and processor
 model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
 processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
 # Load the Unsplash dataset
-dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")
+dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train") # all 25K images are in train split
 
-
+height = 250 # height for resizing images
 
-def predict(image):
+def predict(image, labels):
     inputs = processor(text=[f"a photo of {c}" for c in labels], images=image, return_tensors="pt", padding=True)
     outputs = model(**inputs)
     logits_per_image = outputs.logits_per_image # this is the image-text similarity score
     probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
     return {k: float(v) for k, v in zip(labels, probs[0])}
 
-
 def rand_image():
     n = dataset.num_rows
     r = random.randrange(0,n)
-    return dataset[r]["photo_image_url"]+"?h=250"
+    return dataset[r]["photo_image_url"] + f"?h={height}" # Unsplash allows dynamic requests, including size of image
 
 def set_labels(text):
     global labels
     labels = text.split(",")
 
-
 with gr.Blocks() as demo:
+    labels = gr.State([]) # creates hidden component that can store a value and can be used as input/output; here, initial value is an empty list
     instructions = """## Instructions:
 1. Enter list of labels separated by commas (or select one of the examples below)
 2. Click **Get Random Image** to grab a random image from dataset and analyze it against the labels
@@ -40,20 +40,24 @@ with gr.Blocks() as demo:
     with gr.Row(variant="compact"):
         label_text = gr.Textbox(show_label=False, placeholder="enter labels").style(container=False)
         #submit_btn = gr.Button("Submit").style(full_width=False)
-        gr.Examples(["spring, summer, fall, winter",
+        gr.Examples(["spring, summer, fall, winter",
+                     "mountain, city, beach, ocean, desert, forest, valley",
+                     "red, blue, green, white, black, purple, brown",
+                     "person, animal, landscape, something else",
+                     "day, night, dawn, dusk"], inputs=label_text)
     with gr.Row():
         with gr.Column(variant="panel"):
-            im = gr.Image(interactive=False).style(height=250)
+            im = gr.Image(interactive=False).style(height=height)
             with gr.Row():
                 get_btn = gr.Button("Get Random Image").style(full_width=False)
                 reclass_btn = gr.Button("Re-Classify Image").style(full_width=False)
         cf = gr.Label()
     #submit_btn.click(fn=set_labels, inputs=label_text)
-    label_text.change(fn=set_labels, inputs=label_text)
-    label_text.blur(fn=set_labels, inputs=label_text)
-    label_text.submit(fn=set_labels, inputs=label_text)
+    label_text.change(fn=set_labels, inputs=label_text, outputs=labels) # parse list if changed
+    label_text.blur(fn=set_labels, inputs=label_text, outputs=labels) # parse list if focus is moved elsewhere; ensures that list is fully parsed before classification
+    label_text.submit(fn=set_labels, inputs=label_text, outputs=labels) # parse list if user hits enter; ensures that list is fully parsed before classification
     get_btn.click(fn=rand_image, outputs=im)
-    im.change(predict, inputs=im, outputs=cf)
-    reclass_btn.click(predict, inputs=im, outputs=cf)
+    im.change(predict, inputs=[im, labels], outputs=cf)
+    reclass_btn.click(predict, inputs=[im, labels], outputs=cf)
 
 demo.launch()
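The comment added to `rand_image` relies on Unsplash serving resized renditions through URL query parameters, which is why the app appends `?h={height}` instead of downloading full-size photos. Below is a rough standalone illustration of that idea; it assumes `requests` and `Pillow` are installed and uses the same dataset column named in the diff.

```python
import requests
from io import BytesIO
from PIL import Image
from datasets import load_dataset

dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")

# Ask Unsplash for a rendition roughly 250 pixels tall instead of the original.
url = dataset[0]["photo_image_url"] + "?h=250"
resp = requests.get(url, timeout=30)
resp.raise_for_status()

img = Image.open(BytesIO(resp.content))
print(img.size)  # width is scaled to preserve the aspect ratio
```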