dwb2023 committed
Commit 5dae26f · verified · 1 Parent(s): 7faba42

Upload 5 files

Files changed (5)
  1. README.md +3 -3
  2. app.py +281 -0
  3. inference.py +27 -0
  4. utils.py +0 -0
  5. vae-oid.npz +3 -0
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
- title: Cellvisionai
- emoji: 🌍
+ title: CellVision AI -- Intelligent Cell Imaging Analysis
+ emoji: 🔬🧫
  colorFrom: indigo
- colorTo: yellow
+ colorTo: gray
  sdk: gradio
  sdk_version: 4.36.1
  app_file: app.py
app.py ADDED
@@ -0,0 +1,281 @@
+ import functools
+ import re
+
+ import PIL.Image
+ import gradio as gr
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ import flax.linen as nn
+ from inference import PaliGemmaModel
+
+ # Instantiate the wrapper so pali_gemma_model.infer() below calls a bound method.
+ pali_gemma_model = PaliGemmaModel()
+
+ COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
+
+ ##### Parse segmentation output tokens into masks
+ ##### Also returns bounding boxes with their labels
+
+ def parse_segmentation(input_image, input_text):
+     out = pali_gemma_model.infer(input_image, input_text, max_new_tokens=100)
+     objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
+     labels = set(obj.get('name') for obj in objs if obj.get('name'))
+     color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
+     highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
+     annotated_img = (
+         input_image,
+         [
+             (
+                 obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
+                 obj['name'] or '',
+             )
+             for obj in objs
+             if 'mask' in obj or 'xyxy' in obj
+         ],
+     )
+     has_annotations = bool(annotated_img[1])
+     return annotated_img
+
+ INTRO_TEXT="🔬🧠 CellVision AI -- Intelligent Cell Imaging Analysis 🤖🧫"
+ IMAGE_PROMPT="""
+ Describe the morphological characteristics and visible interactions between different cell types.
+ Assess the biological context to identify signs of cancer and the presence of antigens.
+ """
+
+ with gr.Blocks(css="style.css") as demo:
+     gr.Markdown(INTRO_TEXT)
+     with gr.Tab("Segment/Detect"):
+         with gr.Row():
+             with gr.Column():
+                 image = gr.Image(type="pil")
+                 seg_input = gr.Text(label="Entities to Segment/Detect")
+
+             with gr.Column():
+                 annotated_image = gr.AnnotatedImage(label="Output")
+
+         seg_btn = gr.Button("Submit")
+         examples = [["./examples/cart1.jpg", "segment cells"],
+                     ["./examples/cart1.jpg", "detect cells"],
+                     ["./examples/cart2.jpg", "segment cells"],
+                     ["./examples/cart2.jpg", "detect cells"],
+                     ["./examples/cart3.jpg", "segment cells"],
+                     ["./examples/cart3.jpg", "detect cells"]]
+         gr.Examples(
+             examples=examples,
+             inputs=[image, seg_input],
+         )
+         seg_inputs = [
+             image,
+             seg_input
+         ]
+         seg_outputs = [
+             annotated_image
+         ]
+         seg_btn.click(
+             fn=parse_segmentation,
+             inputs=seg_inputs,
+             outputs=seg_outputs,
+         )
+     with gr.Tab("Text Generation"):
+         with gr.Column():
+             image = gr.Image(type="pil")
+             text_input = gr.Text(label="Input Text")
+
+             text_output = gr.Text(label="Text Output")
+             chat_btn = gr.Button()
+             tokens = gr.Slider(
+                 label="Max New Tokens",
+                 info="Set to larger for longer generation.",
+                 minimum=10,
+                 maximum=100,
+                 value=50,
+                 step=10,
+             )
+
+         chat_inputs = [
+             image,
+             text_input,
+             tokens
+         ]
+         chat_outputs = [
+             text_output
+         ]
+         chat_btn.click(
+             fn=pali_gemma_model.infer,
+             inputs=chat_inputs,
+             outputs=chat_outputs,
+         )
+
+         examples = [["./examples/cart1.jpg", IMAGE_PROMPT],
+                     ["./examples/cart2.jpg", IMAGE_PROMPT],
+                     ["./examples/cart3.jpg", IMAGE_PROMPT]]
+         gr.Examples(
+             examples=examples,
+             inputs=chat_inputs,
+         )
+
+
+ ### Postprocessing Utils for Segmentation Tokens
+ ### Segmentation tokens are passed to another VAE which decodes them to a mask
+
+ _MODEL_PATH = 'vae-oid.npz'
+
+ _SEGMENT_DETECT_RE = re.compile(
+     r'(.*?)' +
+     r'<loc(\d{4})>' * 4 + r'\s*' +
+     '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
+     r'\s*([^;<>]+)? ?(?:; )?',
+ )
+
+
+ def _get_params(checkpoint):
+     """Converts PyTorch checkpoint to Flax params."""
+
+     def transp(kernel):
+         return np.transpose(kernel, (2, 3, 1, 0))
+
+     def conv(name):
+         return {
+             'bias': checkpoint[name + '.bias'],
+             'kernel': transp(checkpoint[name + '.weight']),
+         }
+
+     def resblock(name):
+         return {
+             'Conv_0': conv(name + '.0'),
+             'Conv_1': conv(name + '.2'),
+             'Conv_2': conv(name + '.4'),
+         }
+
+     return {
+         '_embeddings': checkpoint['_vq_vae._embedding'],
+         'Conv_0': conv('decoder.0'),
+         'ResBlock_0': resblock('decoder.2.net'),
+         'ResBlock_1': resblock('decoder.3.net'),
+         'ConvTranspose_0': conv('decoder.4'),
+         'ConvTranspose_1': conv('decoder.6'),
+         'ConvTranspose_2': conv('decoder.8'),
+         'ConvTranspose_3': conv('decoder.10'),
+         'Conv_1': conv('decoder.12'),
+     }
+
+
+ def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
+     batch_size, num_tokens = codebook_indices.shape
+     assert num_tokens == 16, codebook_indices.shape
+     unused_num_embeddings, embedding_dim = embeddings.shape
+
+     encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
+     encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
+     return encodings
+
+
+ @functools.cache
+ def _get_reconstruct_masks():
+     """Reconstructs masks from codebook indices.
+     Returns:
+       A function that expects indices shaped `[B, 16]` of dtype int32, each
+       ranging from 0 to 127 (inclusive), and that returns a decoded masks sized
+       `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
+     """
+
+     class ResBlock(nn.Module):
+         features: int
+
+         @nn.compact
+         def __call__(self, x):
+             original_x = x
+             x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+             x = nn.relu(x)
+             x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+             x = nn.relu(x)
+             x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
+             return x + original_x
+
+     class Decoder(nn.Module):
+         """Upscales quantized vectors to mask."""
+
+         @nn.compact
+         def __call__(self, x):
+             num_res_blocks = 2
+             dim = 128
+             num_upsample_layers = 4
+
+             x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
+             x = nn.relu(x)
+
+             for _ in range(num_res_blocks):
+                 x = ResBlock(features=dim)(x)
+
+             for _ in range(num_upsample_layers):
+                 x = nn.ConvTranspose(
+                     features=dim,
+                     kernel_size=(4, 4),
+                     strides=(2, 2),
+                     padding=2,
+                     transpose_kernel=True,
+                 )(x)
+                 x = nn.relu(x)
+                 dim //= 2
+
+             x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
+
+             return x
+
+     def reconstruct_masks(codebook_indices):
+         quantized = _quantized_values_from_codebook_indices(
+             codebook_indices, params['_embeddings']
+         )
+         return Decoder().apply({'params': params}, quantized)
+
+     with open(_MODEL_PATH, 'rb') as f:
+         params = _get_params(dict(np.load(f)))
+
+     return jax.jit(reconstruct_masks, backend='cpu')
+
+ def extract_objs(text, width, height, unique_labels=False):
+     """Returns objs for a string with "<loc>" and "<seg>" tokens."""
+     objs = []
+     seen = set()
+     while text:
+         m = _SEGMENT_DETECT_RE.match(text)
+         if not m:
+             break
+         print("m", m)
+         gs = list(m.groups())
+         before = gs.pop(0)
+         name = gs.pop()
+         y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
+
+         y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
+         seg_indices = gs[4:20]
+         if seg_indices[0] is None:
+             mask = None
+         else:
+             seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
+             m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
+             m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
+             m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
+             mask = np.zeros([height, width])
+             if y2 > y1 and x2 > x1:
+                 mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0
+
+         content = m.group()
+         if before:
+             objs.append(dict(content=before))
+             content = content[len(before):]
+         while unique_labels and name in seen:
+             name = (name or '') + "'"
+         seen.add(name)
+         objs.append(dict(
+             content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
+         text = text[len(before) + len(content):]
+
+     if text:
+         objs.append(dict(content=text))
+
+     return objs
+
+ #########
+
+ if __name__ == "__main__":
+     demo.queue(max_size=10).launch(debug=True)
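
Note on the parsing above: the detect/segment reply is plain text built from four `<locXXXX>` tokens per object (box corners on a 0-1023 grid) plus optional `<segXXX>` codebook tokens, which `_SEGMENT_DETECT_RE` and `extract_objs` turn into boxes and masks. A minimal standalone sketch of the box parsing, using a hypothetical model reply (real output depends on the model and prompt):

import re

# Hypothetical PaliGemma "detect cells" reply; the actual text varies per image.
sample = "<loc0100><loc0200><loc0800><loc0900> cell ; <loc0150><loc0250><loc0850><loc0950> cell"

# Same pattern that app.py compiles as _SEGMENT_DETECT_RE.
pattern = re.compile(
    r'(.*?)' +
    r'<loc(\d{4})>' * 4 + r'\s*' +
    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
    r'\s*([^;<>]+)? ?(?:; )?',
)

m = pattern.match(sample)
y1, x1, y2, x2 = (int(v) / 1024 for v in m.groups()[1:5])  # normalized corners
label = m.groups()[-1]
print(label, (x1, y1, x2, y2))  # label plus box in [0, 1] coordinates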
inference.py ADDED
@@ -0,0 +1,27 @@
+ import os
+ import PIL.Image  # needed for the PIL.Image.Image type hint on infer()
+ import torch
+ from huggingface_hub import login
+ from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
+ import spaces
+
+ hf_token = os.getenv("HF_TOKEN")
+ login(token=hf_token, add_to_git_credential=True)
+
+ class PaliGemmaModel:
+     def __init__(self):
+         self.model_id = "google/paligemma-3b-mix-448"
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model = PaliGemmaForConditionalGeneration.from_pretrained(self.model_id).eval().to(self.device)
+         self.processor = PaliGemmaProcessor.from_pretrained(self.model_id)
+
+     @spaces.GPU
+     def infer(self, image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
+         inputs = self.processor(text=text, images=image, return_tensors="pt").to(self.device)
+         with torch.inference_mode():
+             generated_ids = self.model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 do_sample=False
+             )
+         result = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
+         return result[0][len(text):].lstrip("\n")
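
A minimal sketch of driving this wrapper outside Gradio, assuming `HF_TOKEN` is set, the `spaces` GPU decorator is usable in the environment (e.g. a ZeroGPU Space), and a local test image exists; the image path below is hypothetical and the weights download on first use:

import PIL.Image
from inference import PaliGemmaModel

model = PaliGemmaModel()                        # loads google/paligemma-3b-mix-448
image = PIL.Image.open("examples/cart1.jpg")    # hypothetical local image
print(model.infer(image, "detect cells", max_new_tokens=50))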
utils.py ADDED
File without changes
vae-oid.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5586010257b8536dddefab65e7755077f21d5672d5674dacf911f73ae95a4447
+ size 8479556
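
The checkpoint ships as a Git LFS pointer; once pulled (e.g. `git lfs pull`), a quick sanity check against the keys that `_get_params()` in app.py expects might look like this sketch:

import numpy as np

ckpt = dict(np.load("vae-oid.npz"))
print(ckpt["_vq_vae._embedding"].shape)                       # VQ codebook used for <segXXX> lookups
print(sorted(k for k in ckpt if k.startswith("decoder.")))    # PyTorch-style decoder weight/bias keys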