"""Gradio demo for "Vocabulary-free Image Classification".

Loads the ``altndrr/cased`` model from the Hugging Face Hub together with the
``openai/clip-vit-large-patch14`` processor and serves a small web UI that
returns label/score pairs for an uploaded (or example) image.
"""

import os
from glob import glob
from typing import Optional

import gradio as gr
import torch
from torchvision.transforms.functional import resize, to_pil_image
from transformers import AutoModel, CLIPProcessor

PAPER_TITLE = "Vocabulary-free Image Classification"
PAPER_URL = "https://arxiv.org/abs/2306.00917"

MARKDOWN_DESCRIPTION = """

Vocabulary-free Image Classification

"""

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): trust_remote_code=True executes model code downloaded from the
# Hub; consider pinning `revision=` to a reviewed commit.
MODEL = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(DEVICE)
PROCESSOR = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")


def save_original_image(image: gr.Image):
    """Downscale a freshly uploaded image and keep a copy of it.

    ``image`` is a PIL image (the input component uses ``type="pil"``). It is
    resized with torchvision's ``resize`` to the processor's ``shortest_edge``
    size, which for an int size matches the image's smaller edge.

    Returns:
        ``(displayed_image, original_copy)``, or ``(None, None)`` when no image
        is given (e.g. the upload was cleared).
    """
    if image is None:
        return None, None
    size = PROCESSOR.image_processor.size["shortest_edge"]
    # Guard against a sequence-valued size: use its smaller side.
    size = min(size) if isinstance(size, (tuple, list)) else size
    image = resize(image, size)
    return image, image.copy()


def prepare_image(image: gr.Image):
    """Preview the image as the CLIP processor prepares it, without normalization.

    Normalization is disabled so the resulting tensor can be turned back into
    a viewable PIL image with ``to_pil_image``.

    Returns:
        ``(preview_image, original_copy)``, or ``(None, None)`` when no image
        is given.
    """
    if image is None:
        return None, None
    # Pass ``do_normalize`` per call rather than flipping the attribute on the
    # shared PROCESSOR: a toggled global could stay disabled if this call
    # raised, or be observed by an inference call running at the same time.
    pixel_values = PROCESSOR.image_processor(
        images=[image], return_tensors="pt", do_normalize=False
    ).pixel_values
    preview = to_pil_image(pixel_values[0])
    return preview, image.copy()


def image_inference(image: gr.Image, alpha: Optional[float] = None):
    """Run the model on ``image`` and return a ``{label: score}`` dict.

    Args:
        image: the input PIL image, or ``None``.
        alpha: forwarded unchanged to the model as its ``alpha`` argument
            (the UI slider supplies it; cached examples leave it ``None``).

    Returns:
        A dict zipping the first entry of the model's ``"vocabularies"`` output
        with the first entry of its ``"scores"`` output (as Python floats),
        suitable for ``gr.Label``; ``None`` when no image is given.
    """
    if image is None:
        return None
    images = PROCESSOR(images=[image], return_tensors="pt", padding=True)
    # MODEL lives on DEVICE, so the input tensors must be moved there as well.
    images = images.to(DEVICE)
    with torch.no_grad():
        outputs = MODEL(images, alpha=alpha)
    vocabulary = outputs["vocabularies"][0]
    scores = outputs["scores"][0].tolist()
    return dict(zip(vocabulary, scores))


with gr.Blocks(analytics_enabled=True, title=PAPER_TITLE, theme="soft") as demo:
    # LAYOUT
    gr.Markdown(MARKDOWN_DESCRIPTION)
    with gr.Row():
        with gr.Column():
            curr_image = gr.Image(label="input", type="pil")
            # Hidden holder for the un-cropped image (upload or selected example).
            _orig_image = gr.Image(
                label="orig. image", type="pil", visible=False, interactive=False
            )
            alpha_slider = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="alpha")
            with gr.Row():
                clear_button = gr.Button(value="Clear", variant="secondary")
                run_button = gr.Button(value="Submit", variant="primary")
        with gr.Column():
            output_label = gr.Label(label="output", num_top_classes=5)
    # Sorted so the example order does not depend on filesystem listing order.
    example_paths = sorted(
        glob(os.path.join(os.path.dirname(__file__), "examples", "*.jpg"))
    )
    examples = gr.Examples(
        examples=example_paths,
        inputs=[_orig_image],
        outputs=[output_label],
        fn=image_inference,
        cache_examples=True,
    )
    gr.Markdown(f"Check out the [original paper]({PAPER_URL}) for more information.")

    # INTERACTIONS
    # - change: a new original image (upload or example) refreshes the preview
    _orig_image.change(prepare_image, [_orig_image], [curr_image, _orig_image])

    # - upload: downscale the upload, store it, and reset the previous output
    curr_image.upload(save_original_image, [curr_image], [curr_image, _orig_image])
    curr_image.upload(lambda: None, [], [output_label])

    # - clear
    curr_image.clear(lambda: (None, None), [], [_orig_image, output_label])

    # - click
    clear_button.click(lambda: (None, None, None), [], [curr_image, _orig_image, output_label])
    run_button.click(image_inference, [curr_image, alpha_slider], [output_label])

if __name__ == "__main__":
    demo.queue()
    demo.launch()