altndrr commited on
Commit
fa73381
1 Parent(s): 0da80f6

Avoid double `prepare_image` when uploading

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -4,7 +4,7 @@ from typing import Optional
4
 
5
  import gradio as gr
6
  import torch
7
- from torchvision.transforms.functional import to_pil_image
8
  from transformers import AutoModel, CLIPProcessor
9
 
10
  PAPER_TITLE = "Vocabulary-free Image Classification"
@@ -36,6 +36,17 @@ MODEL = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(DE
36
  PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
39
  def prepare_image(image: gr.Image):
40
  if image is None:
41
  return None, None
@@ -93,7 +104,7 @@ with gr.Blocks(analytics_enabled=True, title=PAPER_TITLE, theme="soft") as demo:
93
  _orig_image.change(prepare_image, [_orig_image], [curr_image, _orig_image])
94
 
95
  # - upload
96
- curr_image.upload(prepare_image, [curr_image], [curr_image, _orig_image])
97
  curr_image.upload(lambda: None, [], [output_label])
98
 
99
  # - clear
 
4
 
5
  import gradio as gr
6
  import torch
7
+ from torchvision.transforms.functional import resize, to_pil_image
8
  from transformers import AutoModel, CLIPProcessor
9
 
10
  PAPER_TITLE = "Vocabulary-free Image Classification"
 
36
  PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
37
 
38
 
39
+ def save_original_image(image: gr.Image):
40
+ if image is None:
41
+ return None, None
42
+
43
+ size = PROCESSOR.image_processor.size["shortest_edge"]
44
+ size = min(size) if isinstance(size, tuple) else size
45
+ image = resize(image, size)
46
+
47
+ return image, image.copy()
48
+
49
+
50
  def prepare_image(image: gr.Image):
51
  if image is None:
52
  return None, None
 
104
  _orig_image.change(prepare_image, [_orig_image], [curr_image, _orig_image])
105
 
106
  # - upload
107
+ curr_image.upload(save_original_image, [curr_image], [curr_image, _orig_image])
108
  curr_image.upload(lambda: None, [], [output_label])
109
 
110
  # - clear