Avoid double `prepare_image` when uploading
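Why the double call happened can be read off the event wiring in `app.py`: the `upload` handler on `curr_image` invoked `prepare_image` directly and wrote one of its outputs into `_orig_image`, and `_orig_image.change` is itself bound to `prepare_image`, so a single upload ran the preprocessing twice. A minimal sketch of that chain, with component and handler names taken from the diff (the `gr.Blocks` layout is reduced to the two components involved, and the real `prepare_image` body is stubbed out):

```python
import gradio as gr

def prepare_image(image):
    # Stub for the app's real preprocessing; the real one returns
    # the processed image plus the original.
    print("prepare_image called")
    return image, image

with gr.Blocks() as demo:
    curr_image = gr.Image()
    _orig_image = gr.Image()  # holds the original; exact config not shown in the diff

    # Before this commit: an upload ran prepare_image once here...
    curr_image.upload(prepare_image, [curr_image], [curr_image, _orig_image])
    # ...and the write into _orig_image fired this handler, running it again.
    _orig_image.change(prepare_image, [_orig_image], [curr_image, _orig_image])
```

After the fix, the `upload` event instead calls the cheap `save_original_image`, and `prepare_image` runs only once, triggered by `_orig_image.change`.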
app.py CHANGED
@@ -4,7 +4,7 @@ from typing import Optional
 
 import gradio as gr
 import torch
-from torchvision.transforms.functional import to_pil_image
+from torchvision.transforms.functional import resize, to_pil_image
 from transformers import AutoModel, CLIPProcessor
 
 PAPER_TITLE = "Vocabulary-free Image Classification"
@@ -36,6 +36,17 @@ MODEL = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(DE
 PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
 
 
+def save_original_image(image: gr.Image):
+    if image is None:
+        return None, None
+
+    size = PROCESSOR.image_processor.size["shortest_edge"]
+    size = min(size) if isinstance(size, tuple) else size
+    image = resize(image, size)
+
+    return image, image.copy()
+
+
 def prepare_image(image: gr.Image):
     if image is None:
         return None, None
@@ -93,7 +104,7 @@ with gr.Blocks(analytics_enabled=True, title=PAPER_TITLE, theme="soft") as demo:
     _orig_image.change(prepare_image, [_orig_image], [curr_image, _orig_image])
 
     # - upload
-    curr_image.upload(prepare_image, [curr_image], [curr_image, _orig_image])
+    curr_image.upload(save_original_image, [curr_image], [curr_image, _orig_image])
     curr_image.upload(lambda: None, [], [output_label])
 
     # - clear
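The new `save_original_image` also downscales the upload before storing it. For this checkpoint, `PROCESSOR.image_processor.size["shortest_edge"]` is an int (the target length of the shorter side), and the `min(size) if isinstance(size, tuple) else size` guard covers processor configs that store a `(height, width)` pair instead. Passing a single int to torchvision's `resize` scales the shorter edge to that value while preserving the aspect ratio. A standalone check of that behavior (the dimensions below are illustrative):

```python
from PIL import Image
from torchvision.transforms.functional import resize

img = Image.new("RGB", (1024, 768))  # 4:3 landscape test image

# An int target resizes the *shorter* edge; the aspect ratio is kept.
out = resize(img, 224)
print(out.size)  # (298, 224): height hit the target, width scaled to match
```

Returning `image, image.copy()` hands the two output components (`curr_image` and `_orig_image`) separate PIL objects, presumably so later edits to the working image cannot alias the stored original.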