Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import random

import gradio as gr
import numpy as np
import torch
# NOTE(review): `Image` is imported from both datasets and PIL; the PIL import
# shadows the datasets one. Neither appears to be used below — confirm before
# removing either.
from datasets import load_dataset, Image
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
# Zero-shot classifier backbone: CLIP model + matching processor.
# Both are downloaded from the Hugging Face Hub on first run (network I/O).
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load the Unsplash dataset
# Each row carries a hosted photo URL under the "photo_image_url" key,
# which rand_image() below returns directly to the Gradio image component.
dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")

# Current candidate labels for classification.
# Mutated in place by set_labels() (wired to the textbox) and read by predict().
labels = []
def predict(image):
    """Zero-shot classify *image* against the current global ``labels``.

    Builds one CLIP text prompt per label ("a photo of <label>"), scores the
    image against each prompt, and softmaxes the scores into probabilities.

    Parameters
    ----------
    image : image input as accepted by CLIPProcessor (PIL image, numpy array,
        or URL), or ``None`` when the Gradio image component is cleared.

    Returns
    -------
    dict
        Mapping of label -> probability, or an empty dict when there is no
        image or no labels yet (so the UI shows nothing instead of crashing).
    """
    # Guard: im.change fires with None when the image is cleared, and the
    # processor cannot build a text batch from an empty label list.
    if image is None or not labels:
        return {}
    inputs = processor(
        text=[f"a photo of {c}" for c in labels],
        images=image,
        return_tensors="pt",
        padding=True,
    )
    # Inference only — disabling autograd avoids building a graph.
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image  # image-text similarity score
    probs = logits_per_image.softmax(dim=1)  # softmax -> label probabilities
    return {k: float(v) for k, v in zip(labels, probs[0])}
22 |
+
|
23 |
+
|
24 |
+
def rand_image():
|
25 |
+
n = dataset.num_rows
|
26 |
+
r = random.randrange(0,n)
|
27 |
+
return dataset[r]["photo_image_url"]+"?h=250"
|
28 |
+
|
29 |
+
def set_labels(text):
|
30 |
+
global labels
|
31 |
+
labels = text.split(",")
|
32 |
+
|
33 |
+
|
34 |
+
with gr.Blocks() as demo:
|
35 |
+
instructions = """## Instructions:
|
36 |
+
1. Enter list of labels separated by commas (or select one of the examples below)
|
37 |
+
2. Click **Get Random Image** to grab random image from dataset and analyze it against the labels
|
38 |
+
3. Click **Re-Classify Image** to re-run classification on current image after changing labels"""
|
39 |
+
gr.Markdown(instructions)
|
40 |
+
with gr.Row(variant="compact"):
|
41 |
+
label_text = gr.Textbox(show_label=False, placeholder="enter labels").style(container=False)
|
42 |
+
#submit_btn = gr.Button("Submit").style(full_width=False)
|
43 |
+
gr.Examples(["spring, summer, fall, winter", "mountain, city, beach, ocean, desert, plains", "red, blue, green, white, black, purple, brown", "food, vehicle, none of the above"], inputs=label_text)
|
44 |
+
with gr.Row():
|
45 |
+
with gr.Column(variant="panel"):
|
46 |
+
im = gr.Image(interactive=False).style(height=250)
|
47 |
+
with gr.Row():
|
48 |
+
get_btn = gr.Button("Get Random Image").style(full_width=False)
|
49 |
+
reclass_btn = gr.Button("Re-Classify Image").style(full_width=False)
|
50 |
+
cf = gr.Label()
|
51 |
+
#submit_btn.click(fn=set_labels, inputs=label_text)
|
52 |
+
label_text.change(fn=set_labels, inputs=label_text)
|
53 |
+
label_text.blur(fn=set_labels, inputs=label_text)
|
54 |
+
label_text.submit(fn=set_labels, inputs=label_text)
|
55 |
+
get_btn.click(fn=rand_image, outputs=im)
|
56 |
+
im.change(predict, inputs=im, outputs=cf)
|
57 |
+
reclass_btn.click(predict, inputs=im, outputs=cf)
|
58 |
+
|
59 |
+
demo.launch()
|