ryaalbr commited on
Commit
ee0cae7
1 Parent(s): 60c1ed2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from datasets import load_dataset, Image
3
+ import random
4
+ import numpy as np
5
+ from PIL import Image
6
+ from transformers import CLIPProcessor, CLIPModel
7
+
8
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
9
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
10
+
11
+ # Load the Unsplash dataset
12
+ dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")
13
+
14
+ labels = []
15
+
16
+ def predict(image):
17
+ inputs = processor(text=[f"a photo of {c}" for c in labels], images=image, return_tensors="pt", padding=True)
18
+ outputs = model(**inputs)
19
+ logits_per_image = outputs.logits_per_image # this is the image-text similarity score
20
+ probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
21
+ return {k: float(v) for k, v in zip(labels, probs[0])}
22
+
23
+
24
+ def rand_image():
25
+ n = dataset.num_rows
26
+ r = random.randrange(0,n)
27
+ return dataset[r]["photo_image_url"]+"?h=250"
28
+
29
+ def set_labels(text):
30
+ global labels
31
+ labels = text.split(",")
32
+
33
+
34
+ with gr.Blocks() as demo:
35
+ instructions = """## Instructions:
36
+ 1. Enter list of labels separated by commas (or select one of the examples below)
37
+ 2. Click **Get Random Image** to grab random image from dataset and analyze it against the labels
38
+ 3. Click **Re-Classify Image** to re-run classification on current image after changing labels"""
39
+ gr.Markdown(instructions)
40
+ with gr.Row(variant="compact"):
41
+ label_text = gr.Textbox(show_label=False, placeholder="enter labels").style(container=False)
42
+ #submit_btn = gr.Button("Submit").style(full_width=False)
43
+ gr.Examples(["spring, summer, fall, winter", "mountain, city, beach, ocean, desert, plains", "red, blue, green, white, black, purple, brown", "food, vehicle, none of the above"], inputs=label_text)
44
+ with gr.Row():
45
+ with gr.Column(variant="panel"):
46
+ im = gr.Image(interactive=False).style(height=250)
47
+ with gr.Row():
48
+ get_btn = gr.Button("Get Random Image").style(full_width=False)
49
+ reclass_btn = gr.Button("Re-Classify Image").style(full_width=False)
50
+ cf = gr.Label()
51
+ #submit_btn.click(fn=set_labels, inputs=label_text)
52
+ label_text.change(fn=set_labels, inputs=label_text)
53
+ label_text.blur(fn=set_labels, inputs=label_text)
54
+ label_text.submit(fn=set_labels, inputs=label_text)
55
+ get_btn.click(fn=rand_image, outputs=im)
56
+ im.change(predict, inputs=im, outputs=cf)
57
+ reclass_btn.click(predict, inputs=im, outputs=cf)
58
+
59
+ demo.launch()