drhead committed
Commit 37183bc · Parent: d8bb729

Add attention visualization + other updates
JTP_PILOT-e4-vit_so400m_patch14_siglip_384.safetensors DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:2bac5c99e38e946b09b8813e28598783b2aabbea24ecafd04261142343185f69
-size 1754826116
JTP_PILOT2-2-e3-vit_so400m_patch14_siglip_384.safetensors DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:565cbd6d3f453940c12d73aa2496bab102caf9f1c9a2a85433533c768df03555
-size 1796716928
JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:0ac0e46bc773cfb486a83a79de9497566d91359e7962d225afdb7822dffc603d
-size 1796716928
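
The three bundled checkpoints above are dropped in favor of fetching the model at runtime (see the `hf_hub_download` call added to `app.py` below). A minimal sketch of that pattern, assuming the `RedRocket/JointTaggerProject` Hub repo remains available:

```python
# Sketch: resolve the checkpoint from the Hub cache instead of tracking it in
# this repo via Git LFS. Mirrors the call added in app.py below.
from huggingface_hub import hf_hub_download

cached_model = hf_hub_download(
    repo_id="RedRocket/JointTaggerProject",
    subfolder="JTP_PILOT2",
    filename="JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors",
)
# Returns a local file path inside the Hugging Face cache; pass it to
# safetensors.torch.load_model(model, cached_model) to load weights in place.
```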
app.py CHANGED
@@ -1,11 +1,12 @@
 
-import json
+import msgspec
 import os
 import zipfile
 from io import BytesIO
 from tempfile import NamedTemporaryFile
 import tempfile
-
+import numpy as np
+import matplotlib.cm as cm
 import gradio as gr
 import pandas as pd
 from PIL import Image
@@ -23,8 +24,7 @@ from typing import Callable
 from functools import partial
 import spaces.config
 from spaces.zero.decorator import P, R
-
-torch.set_grad_enabled(False)
+from huggingface_hub import hf_hub_download
 
 def _dynGPU(
     fn: Callable[P, R] | None, duration: Callable[P, int], min=10, max=300, step=5
@@ -189,47 +189,130 @@ class GatedHead(torch.nn.Module):
 
 model.head = GatedHead(min(model.head.weight.shape), 9083)
 
-safetensors.torch.load_model(model, "JTP_PILOT2-2-e3-vit_so400m_patch14_siglip_384.safetensors")
+cached_model = hf_hub_download(
+    repo_id="RedRocket/JointTaggerProject",
+    subfolder="JTP_PILOT2",
+    filename="JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors"
+)
+
+safetensors.torch.load_model(model, cached_model)
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model.to(device)
+if torch.cuda.is_available():
+    model.to(device='cuda', dtype=torch.float16, memory_format=torch.channels_last)
 model.eval()
 
-with open("tagger_tags.json", "r") as file:
-    tags = json.load(file)  # type: dict
-allowed_tags = list(tags.keys())
+with open("tagger_tags.json", "rb") as file:
+    tags = msgspec.json.decode(file.read(), type=dict[str, int])
 
-for idx, tag in enumerate(allowed_tags):
-    allowed_tags[idx] = tag.replace("_", " ")
+for tag in list(tags.keys()):  # snapshot keys: the dict is mutated while renaming
+    tags[tag.replace("_", " ")] = tags.pop(tag)
 
-sorted_tag_score = {}
+allowed_tags = list(tags.keys())
 
 @spaces.GPU(duration=6)
-def run_classifier(image, threshold):
-    global sorted_tag_score
+def run_classifier(image: Image.Image, threshold):
     img = image.convert('RGBA')
-    tensor = transform(img).unsqueeze(0).to(device)
+    tensor = transform(img).unsqueeze(0)
 
     with torch.no_grad():
-        probits = model(tensor)[0]
-        values, indices = probits.topk(250)
+        probits = model(tensor)[0]  # type: torch.Tensor
+        values, indices = probits.cpu().topk(250)
+
+    tag_score = {allowed_tags[idx.item()]: val.item() for idx, val in zip(indices, values)}
 
-    tag_score = dict()
-    for i in range(indices.size(0)):
-        tag_score[allowed_tags[indices[i]]] = values[i].item()
     sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
 
-    return create_tags(threshold)
+    return *create_tags(threshold, sorted_tag_score), img, sorted_tag_score
 
-def create_tags(threshold):
-    global sorted_tag_score
+def create_tags(threshold, sorted_tag_score: dict):
     filtered_tag_score = {key: value for key, value in sorted_tag_score.items() if value > threshold}
     text_no_impl = ", ".join(filtered_tag_score.keys())
     return text_no_impl, filtered_tag_score
 
 def clear_image():
-    global sorted_tag_score
-    sorted_tag_score = {}
-    return "", {}
+    return "", {}, None, {}, None
+
+@spaces.GPU(duration=5)
+def cam_inference(img, threshold, alpha, evt: gr.SelectData):
+    target_tag_index = tags[evt.value]
+    tensor = transform(img).unsqueeze(0)
+
+    gradients = {}
+    activations = {}
+
+    def hook_forward(module, input, output):
+        activations['value'] = output
+
+    def hook_backward(module, grad_in, grad_out):
+        gradients['value'] = grad_out[0]
+
+    handle_forward = model.norm.register_forward_hook(hook_forward)
+    handle_backward = model.norm.register_full_backward_hook(hook_backward)
+
+    probits = model(tensor)[0]
+
+    model.zero_grad()
+    probits[target_tag_index].backward(retain_graph=True)
+
+    with torch.no_grad():
+        patch_grads = gradients.get('value')
+        patch_acts = activations.get('value')
+
+        weights = torch.mean(patch_grads, dim=1).squeeze(0)
+
+        cam_1d = torch.einsum('pe,e->p', patch_acts.squeeze(0), weights)
+        cam_1d = torch.relu(cam_1d)
+
+        cam = cam_1d.reshape(27, 27).detach().cpu().numpy()
+
+    handle_forward.remove()
+    handle_backward.remove()
+
+    return create_cam_visualization_pil(img, cam, alpha=alpha, vis_threshold=threshold), cam
+
+def create_cam_visualization_pil(image_pil, cam, alpha=0.6, vis_threshold=0.2):
+    """
+    Overlays CAM on image and returns a PIL image.
+
+    Args:
+        image_pil: PIL Image (RGB)
+        cam: 2D numpy array (activation map)
+        alpha: float, blending factor
+        vis_threshold: float, minimum normalized CAM value to show color
+
+    Returns:
+        PIL.Image.Image with overlay
+    """
+    if cam is None:
+        return image_pil
+    w, h = image_pil.size
+    size = max(w, h)
+
+    # Normalize CAM to [0, 1]
+    cam -= cam.min()
+    cam /= cam.max()
+
+    # Create heatmap using matplotlib colormap
+    colormap = cm.get_cmap('inferno')
+    cam_rgb = colormap(cam)[:, :, :3]  # RGB
+
+    # Create alpha channel
+    cam_alpha = (cam >= vis_threshold).astype(np.float32) * alpha  # Alpha mask
+    cam_rgba = np.dstack((cam_rgb, cam_alpha))  # Shape: (H, W, 4)
+
+    # Coarse upscale for CAM output -- keeps "blocky" effect that is truer to what is measured
+    cam_pil = Image.fromarray((cam_rgba * 255).astype(np.uint8), mode="RGBA")
+    cam_pil = cam_pil.resize((216, 216), resample=Image.Resampling.NEAREST)
+
+    # Model uses padded image as input; this matches the attention map to the input image aspect ratio
+    cam_pil = cam_pil.resize((size, size), resample=Image.Resampling.BICUBIC)
+    cam_pil = transforms.CenterCrop((h, w))(cam_pil)
+
+    # Composite over original
+    composite = Image.alpha_composite(image_pil, cam_pil)
+
+    return composite
 
 class ImageDataset(Dataset):
     def __init__(self, image_files, transform):
@@ -311,35 +394,84 @@ def process_zip(zip_file, threshold):
 
     return temp_file.name, df
 
-with gr.Blocks(css=".output-class { display: none; }") as demo:
+custom_css = """
+.output-class { display: none; }
+.inferno-slider input[type=range] {
+    background: linear-gradient(to right,
+        #000004, #1b0c41, #4a0c6b, #781c6d,
+        #a52c60, #cf4446, #ed6925, #fb9b06,
+        #f7d13d, #fcffa4
+    ) !important;
+    background-size: 100% 100% !important;
+}
+#image_container-image {
+    width: 100%;
+    aspect-ratio: 1 / 1;
+    max-height: 100%;
+}
+#image_container img {
+    object-fit: contain !important;
+}
+"""
+
+with gr.Blocks(css=custom_css) as demo:
     gr.Markdown("""
     ## Joint Tagger Project: JTP-PILOT² Demo **BETA**
-    This tagger is designed for use on furry images (though may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also result in some amount of hallucinated tags.
-
-    This tagger is the result of joint efforts between members of the RedRocket team, with distinctions given to Thessalo for creating the foundation for this project with his efforts, RedHotTensors for redesigning the process into a second-order method that models information expectation, and drhead for dataset prep, creation of training code and supervision of training runs.
-
-    Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
     """)
    with gr.Tabs():
        with gr.TabItem("Single Image"):
+            original_image_state = gr.State()  # stash a copy of the input image
+            sorted_tag_score_state = gr.State(value={})  # stash the sorted tag scores
+            cam_state = gr.State()
            with gr.Row():
                with gr.Column():
-                    image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
-                    threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Threshold")
+                    image = gr.Image(label="Source", sources=['upload', 'clipboard'], type='pil', show_label=False, elem_id="image_container")
+                    cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="CAM Threshold", elem_classes="inferno-slider")
+                    alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
                with gr.Column():
                    tag_string = gr.Textbox(label="Tag String")
+                    threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
                    label_box = gr.Label(label="Tag Predictions", num_top_classes=250, show_label=False)

-            image_input.upload(
+            image.upload(
                fn=run_classifier,
-                inputs=[image_input, threshold_slider],
-                outputs=[tag_string, label_box]
+                inputs=[image, threshold_slider],
+                outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state],
+                show_progress='minimal'
+            )
+
+            image.clear(
+                fn=clear_image,
+                inputs=[],
+                outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state, cam_state]
            )

            threshold_slider.input(
                fn=create_tags,
-                inputs=[threshold_slider],
-                outputs=[tag_string, label_box]
+                inputs=[threshold_slider, sorted_tag_score_state],
+                outputs=[tag_string, label_box],
+                show_progress='hidden'
+            )
+
+            label_box.select(
+                fn=cam_inference,
+                inputs=[original_image_state, cam_slider, alpha_slider],
+                outputs=[image, cam_state],
+                show_progress='minimal'
+            )
+
+            cam_slider.input(
+                fn=create_cam_visualization_pil,
+                inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
+                outputs=[image],
+                show_progress='hidden'
+            )
+
+            alpha_slider.input(
+                fn=create_cam_visualization_pil,
+                inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
+                outputs=[image],
+                show_progress='hidden'
            )

        with gr.TabItem("Multiple Images"):
@@ -357,6 +489,11 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
                inputs=[zip_input, multi_threshold_slider],
                outputs=[zip_output, dataframe_output]
            )
+    gr.Markdown("""
+    This tagger is designed for use on furry images (though it may very well work on out-of-distribution images, potentially with funny results). A threshold of 0.2 is recommended. Lower thresholds often turn up more valid tags, but can also result in some amount of hallucinated tags.
+    This tagger is the result of joint efforts between members of the RedRocket team, with distinctions given to Thessalo for creating the foundation for this project with his efforts, RedHotTensors for redesigning the process into a second-order method that models information expectation, and drhead for dataset prep, creation of training code and supervision of training runs.
+    Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
+    """)

if __name__ == "__main__":
    demo.launch()
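
For reference, `cam_inference` above is a Grad-CAM-style pass: hooks on the model's final norm layer capture patch activations and their gradients with respect to the selected tag's logit, the gradients are averaged over patches into per-channel weights, and each patch embedding is dotted with those weights. A standalone sketch of just that reduction, with assumed shapes (a 27×27 patch grid and an embedding width of 1152; only `reshape(27, 27)` is confirmed by the code above):

```python
# Standalone sketch of the CAM reduction in cam_inference; the tensors here are
# stand-ins for what the forward/backward hooks capture (shapes are assumptions).
import torch

acts = torch.randn(1, 729, 1152)   # forward-hook output: (batch, patches, embed_dim)
grads = torch.randn(1, 729, 1152)  # backward-hook grad_out[0], same shape

weights = grads.mean(dim=1).squeeze(0)                   # (embed_dim,) channel weights
cam = torch.einsum('pe,e->p', acts.squeeze(0), weights)  # weight each patch embedding
cam = torch.relu(cam).reshape(27, 27)                    # keep positive evidence; restore grid
```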
requirements.txt CHANGED
@@ -3,4 +3,6 @@ torchvision
 timm
 pillow
 safetensors
-rarfile
+rarfile
+pydantic==2.10.6
+msgspec
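
The new `msgspec` dependency backs the tag-loading change in `app.py`: decoding straight to `dict[str, int]` validates the file's shape and skips the generic object construction that `json.load` performs. The pattern, mirroring the diff above:

```python
import msgspec

# Decode tagger_tags.json directly into a typed dict; this raises an error
# if the JSON does not match dict[str, int].
with open("tagger_tags.json", "rb") as f:
    tags = msgspec.json.decode(f.read(), type=dict[str, int])
```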