import io
import random
from typing import List, Tuple

import aiohttp
import panel as pn
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

pn.extension(design='bootstrap', sizing_mode="stretch_width")

# Checkpoint used for both the CLIP processor and the CLIP model.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"


async def random_url(_):
    """Return the URL of a random cat or dog image.

    One of the two image-search APIs is picked at random and the ``url``
    field of the first entry in its JSON response is returned.

    The single positional argument is ignored; it only exists so the
    function can be bound to the "Randomize URL" button below.

    Raises:
        aiohttp.ClientResponseError: if the API answers with an HTTP error.
    """
    api_url = random.choice([
        "https://api.thecatapi.com/v1/images/search",
        "https://api.thedogapi.com/v1/images/search",
    ])
    async with aiohttp.ClientSession() as session:
        async with session.get(api_url) as resp:
            # Fail with a clear HTTP error instead of an obscure
            # KeyError/IndexError while indexing an error payload.
            resp.raise_for_status()
            return (await resp.json())[0]["url"]


@pn.cache
def load_processor_model(
    processor_name: str, model_name: str
) -> Tuple[CLIPProcessor, CLIPModel]:
    """Load the CLIP processor and model for the given checkpoint names.

    Wrapped in ``pn.cache`` so repeated calls with the same names reuse the
    already loaded objects.
    """
    processor = CLIPProcessor.from_pretrained(processor_name)
    model = CLIPModel.from_pretrained(model_name)
    return processor, model


async def open_image_url(image_url: str) -> Image.Image:
    """Download ``image_url`` and open the response body as a PIL image.

    Raises:
        aiohttp.ClientResponseError: if the server answers with an HTTP error.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(image_url) as resp:
            # Without this check an HTML error page would be handed to PIL
            # and fail later with a much less helpful message.
            resp.raise_for_status()
            return Image.open(io.BytesIO(await resp.read()))


def get_similarity_scores(
    class_items: List[str], image: Image.Image
) -> List[float]:
    """Score ``image`` against every entry of ``class_items`` with CLIP.

    Returns:
        One value per class item, in input order: the softmax over the
        image-to-text logits for the single input image.
    """
    processor, model = load_processor_model(CLIP_CHECKPOINT, CLIP_CHECKPOINT)
    inputs = processor(
        text=class_items,
        images=[image],
        return_tensors="pt",  # pytorch tensors
    )
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    class_likelihoods = logits_per_image.softmax(dim=1).detach().numpy()
    # Only one image was passed in, so take its row; convert to plain
    # Python floats so the return value matches the annotation.
    return class_likelihoods[0].tolist()


async def process_inputs(class_names: str, image_url: str):
    """Classify the image at ``image_url`` and yield Panel objects.

    This is an async generator: it first yields a Markdown status string and
    finally a ``pn.Column`` holding the image plus one label/progress-bar
    pair per class name.

    Args:
        class_names: Comma separated class names, as typed into the text
            input. Surrounding whitespace is stripped and empty entries are
            ignored.
        image_url: URL of the image to classify.
    """
    if not image_url:
        yield '## Provide an image URL'
        return

    # Clean the names before they reach the model: "cat, dog" must not be
    # scored as "cat" and " dog", and a trailing comma must not add an
    # empty class.
    class_items = [name.strip() for name in class_names.split(",")]
    class_items = [name for name in class_items if name]
    if not class_items:
        yield '## Provide at least one class name'
        return

    yield '## Fetching image and running model ⚙'

    pil_img = await open_image_url(image_url)
    img = pn.pane.Image(pil_img, height=400, align='center')

    class_likelihoods = get_similarity_scores(class_items, pil_img)

    # build the results column
    results = pn.Column("## 🎉 Here are the results!", img)

    for class_item, class_likelihood in zip(class_items, class_likelihoods):
        row_label = pn.widgets.StaticText(
            name=class_item, value=f"{class_likelihood:.2%}", align='center'
        )
        row_bar = pn.indicators.Progress(
            value=int(class_likelihood * 100),
            sizing_mode="stretch_width",
            bar_color="secondary",
            margin=(0, 10),
            design=pn.theme.Material,
        )
        results.append(pn.Column(row_label, row_bar))
    yield results


# create widgets
randomize_url = pn.widgets.Button(name="Randomize URL", align="end")

image_url = pn.widgets.TextInput(
    name="Image URL to classify",
    # the text input's value is bound to random_url, driven by the button
    value=pn.bind(random_url, randomize_url),
)
class_names = pn.widgets.TextInput(
    name="Comma separated class names",
    placeholder="Enter possible class names, e.g. cat, dog",
    value="cat, dog, parrot",
)

input_widgets = pn.Column(
    "## 😊 Click randomize or paste a URL to start classifying!",
    pn.Row(image_url, randomize_url),
    class_names,
)

# add interactivity
interactive_result = pn.bind(
    process_inputs, image_url=image_url, class_names=class_names
)

# create dashboard
main = pn.WidgetBox(
    input_widgets,
    interactive_result,
)

pn.template.BootstrapTemplate(
    title="Panel Image Classification Demo",
    main=main,
    main_max_width="min(50%, 698px)",
    header_background="#F08080",
).servable(title="Panel Image Classification Demo")