sophiamyang commited on
Commit
307197c
β€’
1 Parent(s): ebf7741

async version of the image classifier

Browse files
Files changed (3) hide show
  1. Dockerfile +16 -0
  2. app.py +119 -0
  3. requirements.txt +5 -0
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ WORKDIR /code
4
+
5
+ COPY ./requirements.txt /code/requirements.txt
6
+ RUN python3 -m pip install --no-cache-dir --upgrade pip
7
+ RUN python3 -m pip install --no-cache-dir --upgrade -r /code/requirements.txt
8
+
9
+ COPY . .
10
+
11
+ CMD ["panel", "serve", "/code/app.py", "--address", "0.0.0.0", "--port", "7860", "--allow-websocket-origin", "*"]
12
+
13
+ RUN mkdir /.cache
14
+ RUN chmod 777 /.cache
15
+ RUN mkdir .chroma
16
+ RUN chmod 777 .chroma
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import aiohttp
2
+ import io
3
+ import random
4
+ import panel as pn
5
+
6
+ from PIL import Image
7
+
8
+ from transformers import CLIPProcessor, CLIPModel
9
+ from typing import List, Tuple
10
+
11
+ pn.extension(design='bootstrap', sizing_mode="stretch_width")
12
+
13
+ async def random_url(_):
14
+ api_url = random.choice([
15
+ "https://api.thecatapi.com/v1/images/search",
16
+ "https://api.thedogapi.com/v1/images/search"
17
+ ])
18
+ async with aiohttp.ClientSession() as session:
19
+ async with session.get(api_url) as resp:
20
+ return (await resp.json())[0]["url"]
21
+
22
+ @pn.cache
23
+ def load_processor_model(
24
+ processor_name: str, model_name: str
25
+ ) -> Tuple[CLIPProcessor, CLIPModel]:
26
+ processor = CLIPProcessor.from_pretrained(processor_name)
27
+ model = CLIPModel.from_pretrained(model_name)
28
+ return processor, model
29
+
30
+
31
+ async def open_image_url(image_url: str) -> Image:
32
+ async with aiohttp.ClientSession() as session:
33
+ async with session.get(image_url) as resp:
34
+ return Image.open(io.BytesIO(await resp.read()))
35
+
36
+
37
+ def get_similarity_scores(class_items: List[str], image: Image) -> List[float]:
38
+ processor, model = load_processor_model(
39
+ "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32"
40
+ )
41
+ inputs = processor(
42
+ text=class_items,
43
+ images=[image],
44
+ return_tensors="pt", # pytorch tensors
45
+ )
46
+ outputs = model(**inputs)
47
+ logits_per_image = outputs.logits_per_image
48
+ class_likelihoods = logits_per_image.softmax(dim=1).detach().numpy()
49
+ return class_likelihoods[0]
50
+
51
+
52
+ async def process_inputs(class_names: List[str], image_url: str):
53
+ """
54
+ High level function that takes in the user inputs and returns the
55
+ classification results as panel objects.
56
+ """
57
+ if not image_url:
58
+ yield '## Provide an image URL'
59
+ return
60
+ yield '## Fetching image and running model βš™'
61
+ pil_img = await open_image_url(image_url)
62
+ img = pn.pane.Image(pil_img, height=400, align='center')
63
+
64
+ class_items = class_names.split(",")
65
+ class_likelihoods = get_similarity_scores(class_items, pil_img)
66
+
67
+ # build the results column
68
+ results = pn.Column("## πŸŽ‰ Here are the results!", img)
69
+
70
+ for class_item, class_likelihood in zip(class_items, class_likelihoods):
71
+ row_label = pn.widgets.StaticText(
72
+ name=class_item.strip(), value=f"{class_likelihood:.2%}", align='center'
73
+ )
74
+ row_bar = pn.indicators.Progress(
75
+ value=int(class_likelihood * 100),
76
+ sizing_mode="stretch_width",
77
+ bar_color="secondary",
78
+ margin=(0, 10),
79
+ design=pn.theme.Material
80
+ )
81
+ results.append(pn.Column(row_label, row_bar))
82
+ yield results
83
+
84
+ # create widgets
85
+ randomize_url = pn.widgets.Button(name="Randomize URL", align="end")
86
+
87
+ image_url = pn.widgets.TextInput(
88
+ name="Image URL to classify",
89
+ value=pn.bind(random_url, randomize_url),
90
+ )
91
+ class_names = pn.widgets.TextInput(
92
+ name="Comma separated class names",
93
+ placeholder="Enter possible class names, e.g. cat, dog",
94
+ value="cat, dog, parrot",
95
+ )
96
+
97
+ input_widgets = pn.Column(
98
+ "## 😊 Click randomize or paste a URL to start classifying!",
99
+ pn.Row(image_url, randomize_url),
100
+ class_names,
101
+ )
102
+
103
+ # add interactivity
104
+ interactive_result = pn.bind(
105
+ process_inputs, image_url=image_url, class_names=class_names
106
+ )
107
+
108
+ # create dashboard
109
+ main = pn.WidgetBox(
110
+ input_widgets,
111
+ interactive_result,
112
+ )
113
+
114
+ pn.template.BootstrapTemplate(
115
+ title="Panel Image Classification Demo",
116
+ main=main,
117
+ main_max_width="min(50%, 698px)",
118
+ header_background="#F08080",
119
+ ).servable(title="Panel Image Classification Demo");
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ panel
2
+ jupyter
3
+ transformers
4
+ numpy
5
+ torch