ryaalbr committed
Commit f7512c4
1 Parent(s): a58968d

Update app.py


Change labels from global variable to Gradio state component (see the sketch after this list)
Add comments
Move image height to variable
Update examples
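The first item is the substantive change: a module-level labels list is replaced by a gr.State component. In case the pattern is unfamiliar, here is a minimal sketch of how a Gradio state component carries a per-session value between event handlers; the component and function names below are illustrative, not taken from app.py.

import gradio as gr

def parse_labels(text):
    # Return the new value; Gradio stores it in the state component
    return [s.strip() for s in text.split(",")]

def count_labels(labels):
    # A state component can be passed to a handler like any other input
    return f"{len(labels)} labels"

with gr.Blocks() as demo:
    labels = gr.State([])  # hidden, per-session value; starts as an empty list
    box = gr.Textbox(label="Labels (comma-separated)")
    btn = gr.Button("Count")
    out = gr.Textbox(label="Result")
    box.change(parse_labels, inputs=box, outputs=labels)  # update the state on every edit
    btn.click(count_labels, inputs=labels, outputs=out)   # read the state on click

demo.launch()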

Files changed (1)
  1. app.py +18 -15
app.py CHANGED
@@ -5,33 +5,32 @@ import numpy as np
 from PIL import Image
 from transformers import CLIPProcessor, CLIPModel
 
+# Load the pre-trained model and processor
 model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
 processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
 
 # Load the Unsplash dataset
-dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")
+dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train") # all 25K images are in the train split
 
-labels = []
+height = 250 # height for resizing images
 
-def predict(image):
+def predict(image, labels):
     inputs = processor(text=[f"a photo of {c}" for c in labels], images=image, return_tensors="pt", padding=True)
     outputs = model(**inputs)
     logits_per_image = outputs.logits_per_image # this is the image-text similarity score
     probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
     return {k: float(v) for k, v in zip(labels, probs[0])}
 
-
 def rand_image():
     n = dataset.num_rows
     r = random.randrange(0,n)
-    return dataset[r]["photo_image_url"]+"?h=250"
+    return dataset[r]["photo_image_url"] + f"?h={height}" # Unsplash allows dynamic requests, including size of image
 
 def set_labels(text):
-    global labels
-    labels = text.split(",")
+    return text.split(",") # return the parsed list so Gradio stores it in the labels state component
 
-
 with gr.Blocks() as demo:
+    labels = gr.State([]) # creates hidden component that can store a value and can be used as input/output; here, initial value is an empty list
     instructions = """## Instructions:
     1. Enter list of labels separated by commas (or select one of the examples below)
     2. Click **Get Random Image** to grab a random image from dataset and analyze it against the labels
@@ -40,20 +39,24 @@ with gr.Blocks() as demo:
     with gr.Row(variant="compact"):
         label_text = gr.Textbox(show_label=False, placeholder="enter labels").style(container=False)
         #submit_btn = gr.Button("Submit").style(full_width=False)
-        gr.Examples(["spring, summer, fall, winter", "mountain, city, beach, ocean, desert, plains", "red, blue, green, white, black, purple, brown", "food, vehicle, none of the above"], inputs=label_text)
+        gr.Examples(["spring, summer, fall, winter",
+                     "mountain, city, beach, ocean, desert, forest, valley",
+                     "red, blue, green, white, black, purple, brown",
+                     "person, animal, landscape, something else",
+                     "day, night, dawn, dusk"], inputs=label_text)
     with gr.Row():
         with gr.Column(variant="panel"):
-            im = gr.Image(interactive=False).style(height=250)
+            im = gr.Image(interactive=False).style(height=height)
             with gr.Row():
                 get_btn = gr.Button("Get Random Image").style(full_width=False)
                 reclass_btn = gr.Button("Re-Classify Image").style(full_width=False)
         cf = gr.Label()
     #submit_btn.click(fn=set_labels, inputs=label_text)
-    label_text.change(fn=set_labels, inputs=label_text)
-    label_text.blur(fn=set_labels, inputs=label_text)
-    label_text.submit(fn=set_labels, inputs=label_text)
+    label_text.change(fn=set_labels, inputs=label_text, outputs=labels) # parse list if changed
+    label_text.blur(fn=set_labels, inputs=label_text, outputs=labels) # parse list if focus is moved elsewhere; ensures that list is fully parsed before classification
+    label_text.submit(fn=set_labels, inputs=label_text, outputs=labels) # parse list if user hits enter; ensures that list is fully parsed before classification
     get_btn.click(fn=rand_image, outputs=im)
-    im.change(predict, inputs=im, outputs=cf)
-    reclass_btn.click(predict, inputs=im, outputs=cf)
+    im.change(predict, inputs=[im, labels], outputs=cf)
+    reclass_btn.click(predict, inputs=[im, labels], outputs=cf)
 
 demo.launch()
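For readers who want to exercise the classification step outside the UI, here is a minimal standalone sketch of the zero-shot scoring that predict() performs; the image path and label list are illustrative assumptions, not taken from this repo.

from PIL import Image
from transformers import CLIPProcessor, CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

labels = ["day", "night", "dawn", "dusk"]
image = Image.open("example.jpg")  # hypothetical local test image
inputs = processor(text=[f"a photo of {c}" for c in labels], images=image, return_tensors="pt", padding=True)
probs = model(**inputs).logits_per_image.softmax(dim=1)  # image-text similarity scores -> probabilities
print({k: float(v) for k, v in zip(labels, probs[0])})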