In [None]:
import panel as pn
pn.extension()
import requests
import random
import PIL
from PIL import Image
import io
from transformers import CLIPProcessor, CLIPModel
import numpy as np

In [None]:
# Load the "texteditor" extension, render the app inside the Bootstrap template
# and let components stretch to the available width by default.
pn.extension("texteditor", template="bootstrap", sizing_mode="stretch_width")

# Cap the width of the main content area and recolor the template header.
pn.state.template.param.update(main_max_width="690px", header_background="#F08080")

In [None]:
# Upload control for the image to classify.
file_input = pn.widgets.FileInput()

# Button whose value is bound to the classification callback.
compute_button = pn.widgets.Button(name="Compute")

# Free-text field holding the comma-separated candidate class names.
text_input = pn.widgets.TextInput(
    name="Possible class names (e.g., cat, dog)",
    placeholder="cat, dog",
)

In [None]:
def normalize_image(value, width=600):
    """
    Normalize an image to RGB channels and a fixed width.

    Parameters
    ----------
    value : bytes or None
        Raw bytes of an image file. When empty or ``None``, a sample image is
        downloaded and used instead.
    width : int, default 600
        Target width in pixels; the height is scaled to keep the aspect ratio.

    Returns
    -------
    PIL.Image.Image
        The image converted to RGB and resized to ``width`` pixels wide.

    Raises
    ------
    requests.RequestException
        If the fallback sample image cannot be downloaded (connection error,
        timeout, or a non-success HTTP status).
    """
    if value:
        image = Image.open(io.BytesIO(value))
    else:
        # Nothing was supplied: fall back to a sample image hosted at cocodataset.org.
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        # A timeout keeps the callback from hanging forever, and the status
        # check turns an HTTP error page into a clear exception instead of an
        # "unidentified image" failure later on.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))

    # Convert on both paths so the result always has exactly three channels.
    image = image.convert("RGB")

    source_width, source_height = image.size
    # Keep at least one pixel of height so extremely wide images stay resizable.
    height = max(1, int(source_height / source_width * width))
    return image.resize((width, height), Image.LANCZOS)

In [None]:
_CLIP_CHECKPOINT = "openai/clip-vit-large-patch14"
_clip_cache = {}


def _load_clip(checkpoint=_CLIP_CHECKPOINT):
    """Return the (model, processor) pair for ``checkpoint``, loading it only once."""
    if checkpoint not in _clip_cache:
        _clip_cache[checkpoint] = (
            CLIPModel.from_pretrained(checkpoint),
            CLIPProcessor.from_pretrained(checkpoint),
        )
    return _clip_cache[checkpoint]


def image_classification(image):
    """
    Score ``image`` against the candidate class names typed into ``text_input``.

    The class names are read from ``text_input.value`` as a comma-separated
    list; when the field is empty, ``cat`` and ``dog`` are used.

    Parameters
    ----------
    image : PIL.Image.Image
        The image to classify.

    Returns
    -------
    numpy.ndarray
        Softmax probabilities over the candidate class names, in the order the
        names were entered. Callers read the scores for the single input image
        from row 0.
    """
    # Imported here because the notebook's import cell does not import torch
    # itself, although it is already required by ``return_tensors="pt"``.
    import torch

    # Loading the checkpoint is by far the most expensive step, so reuse the
    # cached pair instead of reloading it on every call.
    model, processor = _load_clip()

    if text_input.value == '':
        possible_categories = ['cat', 'dog']
    else:
        # Strip surrounding whitespace ("cat, dog" -> "cat", "dog") but keep
        # every entry, so positions still line up with callers that split the
        # same text themselves.
        possible_categories = [name.strip() for name in text_input.value.split(",")]

    inputs = processor(text=possible_categories, images=image, return_tensors="pt", padding=True)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
    probs = logits_per_image.softmax(dim=1)
    return probs.detach().numpy()

In [None]:
def get_result(_):
    """
    Classify the current image and render one progress bar per class name.

    Parameters
    ----------
    _ : object
        Value supplied by ``pn.bind`` from the bound widget; it is ignored and
        only serves to trigger a recomputation.

    Returns
    -------
    pn.Column
        One row per candidate class name, each holding the name and a progress
        bar showing its probability as a whole percentage.
    """
    image = normalize_image(file_input.value)
    result = image_classification(image)

    if text_input.value == '':
        labels = ['cat', 'dog']
    else:
        # Strings placed in a Row are rendered as Markdown, where leading
        # whitespace can change the rendering (e.g. an indented code block),
        # so show the trimmed names.
        labels = [label.strip() for label in text_input.value.split(",")]

    rows = [
        pn.Row(
            label,
            # Round rather than truncate so e.g. 0.996 is shown as 100, not 99.
            pn.indicators.Progress(name='', value=int(round(float(prob) * 100)), width=500),
        )
        for label, prob in zip(labels, result[0])
    ]
    return pn.Column(*rows)
 

In [None]:
# Bind the get_result function to the button widget so that the result is
# recomputed whenever the button's value changes.
interactive_result = pn.bind(get_result, compute_button)

In [None]:
# Page layout, top to bottom: heading, upload control, a preview bound to the
# uploaded file, the class-name input, the compute button and the result.
pn.Column(
    "## \U0001F60A Upload an image file and start classifying!",
    file_input,
    pn.bind(pn.panel, file_input),
    text_input,
    compute_button,
    interactive_result,
).servable(title="Panel Image Classification Demo")