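"""Gradio demo for "Vocabulary-free Image Classification" (https://arxiv.org/abs/2306.00917)."""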
import os
from glob import glob
from typing import Optional

import gradio as gr
import torch
from torchvision.transforms.functional import resize, to_pil_image
from transformers import AutoModel, CLIPProcessor

PAPER_TITLE = "Vocabulary-free Image Classification"
PAPER_URL = "https://arxiv.org/abs/2306.00917"

MARKDOWN_DESCRIPTION = """
<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 1rem;">
    <h1>Vocabulary-free Image Classification</h1>
</div>
<div style="display: flex;
            flex-wrap: wrap;
            align-items: center;
            justify-content: center;
            margin-bottom: 1rem;">
    <a href="https://github.com/altndrr/vic" style="margin-right: 0.5rem; margin-bottom: 0.5rem;">
        <img src="https://img.shields.io/badge/code-github.altndrr%2Fvic-blue.svg"/>
    </a>
    <a href="https://huggingface.co/spaces/altndrr/vic" style="margin-right: 0.5rem;
       margin-bottom: 0.5rem;">
        <img src="https://img.shields.io/badge/demo-hf.altndrr%2Fvic-yellow.svg"/>
    </a>
    <a href="https://arxiv.org/abs/2306.00917" style="margin-right: 0.5rem;
       margin-bottom: 0.5rem;">
        <img src="https://img.shields.io/badge/paper-arXiv.2306.00917-B31B1B.svg"/>
    </a>
    <a href="https://alessandroconti.me/papers/2306.00917" style="margin-right: 0.5rem;
       margin-bottom: 0.5rem;">
        <img src="https://img.shields.io/badge/website-gh--pages.altndrr%2Fvic-success.svg"/>
    </a>
</div>
"""

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(DEVICE)
PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")


def image_preprocess(image: gr.Image):
    """Preprocess the input image and return the model-view preview plus a copy of the input."""
    if image is None:
        return None, None

    size = PROCESSOR.image_processor.size["shortest_edge"]
    size = min(size) if isinstance(size, tuple) else size
    image = resize(image, size)

    # Skip normalization so the processed tensor can be converted back to a viewable image.
    PROCESSOR.image_processor.do_normalize = False
    image_tensor = PROCESSOR(images=[image], return_tensors="pt", padding=True)
    PROCESSOR.image_processor.do_normalize = True
    image_tensor = image_tensor.pixel_values[0]
    curr_image = to_pil_image(image_tensor)

    return curr_image, image.copy()


def image_inference(image: gr.Image, alpha: Optional[float] = None):
    """Run the model on the image and return a {candidate name: score} mapping."""
    if image is None:
        return None

    images = PROCESSOR(images=[image], return_tensors="pt", padding=True)

    with torch.no_grad():
        outputs = MODEL(images, alpha=alpha)

    vocabulary = outputs["vocabularies"][0]
    scores = outputs["scores"][0].tolist()
    confidences = dict(zip(vocabulary, scores))

    return confidences


with gr.Blocks(analytics_enabled=True, title=PAPER_TITLE, theme="soft") as demo:
    # LAYOUT
    gr.Markdown(MARKDOWN_DESCRIPTION)
    with gr.Row():
        with gr.Column():
            curr_image = gr.Image(
                label="input", type="pil", sources=["upload", "webcam", "clipboard"]
            )
            alpha_slider = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="alpha")
            with gr.Row():
                clear_button = gr.Button(value="Clear", variant="secondary")
                run_button = gr.Button(value="Submit", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="output", num_top_classes=5)
    _orig_image = gr.Image(label="original image", type="pil", visible=False, interactive=False)
    _example_image = gr.Image(label="example image", type="pil", visible=False, interactive=False)
    examples = gr.Examples(
        examples=glob(os.path.join(os.path.dirname(__file__), "examples", "*.jpg")),
        inputs=[_example_image],
        outputs=[output_label],
        fn=image_inference,
        cache_examples=True,
    )
    gr.Markdown(f"Check out the <a href={PAPER_URL}>original paper</a> for more information.")

    # INTERACTIONS
    # - change
    _example_image.change(image_preprocess, [_example_image], [curr_image, _orig_image])
    # - upload
    curr_image.upload(image_preprocess, [curr_image], [curr_image, _orig_image])
    curr_image.upload(lambda: None, [], [output_label])
    # - clear
    curr_image.clear(lambda: (None, None), [], [_orig_image, output_label])
    # - click
    clear_button.click(lambda: (None, None, None), [], [curr_image, _orig_image, output_label])
    run_button.click(image_inference, [curr_image, alpha_slider], [output_label])


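# Enable request queuing before launching the app.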
if __name__ == "__main__":
    demo.queue()
    demo.launch()