VisionLanguageGroup committed on
Commit
86072ea
·
1 Parent(s): 4ce5a27
_utils/attn_utils_new.py CHANGED
@@ -37,12 +37,6 @@ class CountingCrossAttnProcessor1:
37
  context = encoder_hidden_states if is_cross else hidden_states
38
  k = attn_layer.to_k(context)
39
  v = attn_layer.to_v(context)
40
- # q = attn_layer.reshape_heads_to_batch_dim(q)
41
- # k = attn_layer.reshape_heads_to_batch_dim(k)
42
- # v = attn_layer.reshape_heads_to_batch_dim(v)
43
- # q = attn_layer.head_to_batch_dim(q)
44
- # k = attn_layer.head_to_batch_dim(k)
45
- # v = attn_layer.head_to_batch_dim(v)
46
  q = self.head_to_batch_dim(q, h)
47
  k = self.head_to_batch_dim(k, h)
48
  v = self.head_to_batch_dim(v, h)
@@ -57,11 +51,8 @@ class CountingCrossAttnProcessor1:
57
 
58
  # attention, what we cannot get enough of
59
  attn_ = sim.softmax(dim=-1).clone()
60
- # softmax = nn.Softmax(dim=-1)
61
- # attn_ = softmax(sim)
62
  self.attnstore(attn_, is_cross, self.place_in_unet)
63
  out = torch.einsum("b i j, b j d -> b i d", attn_, v)
64
- # out = attn_layer.batch_to_head_dim(out)
65
  out = self.batch_to_head_dim(out, h)
66
 
67
  if type(attn_layer.to_out) is torch.nn.modules.container.ModuleList:
@@ -112,9 +103,6 @@ def register_attention_control(model, controller):
112
  continue
113
 
114
  cross_att_count += 1
115
- # attn_procs[name] = AttendExciteCrossAttnProcessor(
116
- # attnstore=controller, place_in_unet=place_in_unet
117
- # )
118
  attn_procs[name] = CountingCrossAttnProcessor1(
119
  attnstore=controller, place_in_unet=place_in_unet
120
  )
 
37
  context = encoder_hidden_states if is_cross else hidden_states
38
  k = attn_layer.to_k(context)
39
  v = attn_layer.to_v(context)
 
 
 
 
 
 
40
  q = self.head_to_batch_dim(q, h)
41
  k = self.head_to_batch_dim(k, h)
42
  v = self.head_to_batch_dim(v, h)
 
51
 
52
  # attention, what we cannot get enough of
53
  attn_ = sim.softmax(dim=-1).clone()
 
 
54
  self.attnstore(attn_, is_cross, self.place_in_unet)
55
  out = torch.einsum("b i j, b j d -> b i d", attn_, v)
 
56
  out = self.batch_to_head_dim(out, h)
57
 
58
  if type(attn_layer.to_out) is torch.nn.modules.container.ModuleList:
 
103
  continue
104
 
105
  cross_att_count += 1
 
 
 
106
  attn_procs[name] = CountingCrossAttnProcessor1(
107
  attnstore=controller, place_in_unet=place_in_unet
108
  )
_utils/load_track_data.py CHANGED
@@ -49,9 +49,7 @@ def _load_tiffs(folder: Path, dtype=None):
49
 
50
  def load_track_images(file_dir):
51
 
52
- # suffix_ = [".png", ".tif", ".tiff", ".jpg"]
53
  def find_tif_dir(root_dir):
54
- """递归查找.tif 文件"""
55
  tif_files = []
56
  for dirpath, _, filenames in os.walk(root_dir):
57
  if '__MACOSX' in dirpath:
@@ -112,7 +110,3 @@ def load_track_images(file_dir):
112
 
113
  return imgs, imgs_raw, images_stable, imgs_, imgs_01, height, width
114
 
115
- if __name__ == "__main__":
116
- file_dir = "data/2D+Time/DIC-C2DH-HeLa/train/DIC-C2DH-HeLa/02"
117
- imgs, imgs_raw, images_stable, imgs_, imgs_01, height, width = load_track_images(file_dir)
118
- print(imgs.shape, imgs_raw.shape, images_stable.shape, imgs_.shape, imgs_01.shape, height, width)
 
49
 
50
  def load_track_images(file_dir):
51
 
 
52
  def find_tif_dir(root_dir):
 
53
  tif_files = []
54
  for dirpath, _, filenames in os.walk(root_dir):
55
  if '__MACOSX' in dirpath:
 
110
 
111
  return imgs, imgs_raw, images_stable, imgs_, imgs_01, height, width
112
 
 
 
 
 
_utils/track_args.py DELETED
@@ -1,62 +0,0 @@
1
- import configargparse
2
-
3
-
4
- def parse_train_args():
5
- parser = configargparse.ArgumentParser(
6
- formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
7
- config_file_parser_class=configargparse.YAMLConfigFileParser,
8
- allow_abbrev=False,
9
- )
10
- parser.add_argument(
11
- "-c",
12
- "--config",
13
- default="_utils/example_config.yaml",
14
- is_config_file=True,
15
- help="config file path",
16
- )
17
- parser.add_argument("-d", "--d_model", type=int, default=256)
18
- parser.add_argument("-w", "--window", type=int, default=10)
19
- parser.add_argument("--spatial_pos_cutoff", type=int, default=256)
20
- parser.add_argument("--num_encoder_layers", type=int, default=6)
21
- parser.add_argument("--num_decoder_layers", type=int, default=6)
22
- parser.add_argument("--pos_embed_per_dim", type=int, default=32)
23
- parser.add_argument("--feat_embed_per_dim", type=int, default=8)
24
- parser.add_argument("--dropout", type=float, default=0.00)
25
- parser.add_argument(
26
- "--attn_positional_bias",
27
- type=str,
28
- choices=["rope", "bias", "none"],
29
- default="rope",
30
- )
31
- parser.add_argument("--attn_positional_bias_n_spatial", type=int, default=16)
32
- parser.add_argument("--attn_dist_mode", default="v0")
33
- parser.add_argument(
34
- "--causal_norm",
35
- type=str,
36
- choices=["none", "linear", "softmax", "quiet_softmax"],
37
- default="quiet_softmax",
38
- )
39
-
40
- args, unknown_args = parser.parse_known_args()
41
-
42
- # # Hack to allow for --input_test
43
- # allowed_unknown = ["input_test"]
44
- # if not set(a.split("=")[0].strip("-") for a in unknown_args).issubset(
45
- # set(allowed_unknown)
46
- # ):
47
- # raise ValueError(f"Unknown args: {unknown_args}")
48
-
49
- # pprint(vars(args))
50
-
51
- # for backward compatibility
52
- # if args.attn_positional_bias == "True":
53
- # args.attn_positional_bias = "bias"
54
- # elif args.attn_positional_bias == "False":
55
- # args.attn_positional_bias = False
56
-
57
- # if args.train_samples == 0:
58
- # raise NotImplementedError(
59
- # "--train_samples must be > 0, full dataset pass not supported."
60
- # )
61
-
62
- return args
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -937,12 +937,21 @@ with gr.Blocks(
937
  ) as demo:
938
  gr.Markdown(
939
  """
940
- # 🔬 Microscopy Image Analysis Suite
941
 
942
- Supporting three key tasks:
943
  - 🎨 **Segmentation**: Instance segmentation of microscopic objects
944
  - 🔢 **Counting**: Counting microscopic objects based on density maps
945
  - 🎬 **Tracking**: Tracking microscopic objects in video sequences
 
 
 
 
 
 
 
 
 
946
  """
947
  )
948
 
@@ -1667,26 +1676,12 @@ with gr.Blocks(
1667
  outputs=[feedback_status, feedback_status]
1668
  )
1669
 
1670
- gr.Markdown(
1671
- """
1672
- ---
1673
- ### 📒 Note:
1674
-
1675
- This project is currently available with usage limits for research trial use and feedback collection. We plan to release a free public version in the future. We are actively improving the toolkit and greatly appreciate your feedback!
1676
-
1677
-
1678
-
1679
- ### 💡 Technical Details
1680
-
1681
- **MicroscopyMatching** - A general-purpose microscopy image analysis toolkit based on Stable Diffusion
1682
-
1683
- """
1684
- )
1685
 
1686
  if __name__ == "__main__":
1687
  demo.queue().launch(
1688
  server_name="0.0.0.0",
1689
- server_port=7860,
1690
  share=False,
1691
  ssr_mode=False,
1692
  show_error=True,
 
937
  ) as demo:
938
  gr.Markdown(
939
  """
940
+ # 🔬 MicroscopyMatching: Microscopy Image Analysis Suite
941
 
942
+ ### Supporting three key tasks:
943
  - 🎨 **Segmentation**: Instance segmentation of microscopic objects
944
  - 🔢 **Counting**: Counting microscopic objects based on density maps
945
  - 🎬 **Tracking**: Tracking microscopic objects in video sequences
946
+
947
+ ### 💡 Technical Details:
948
+
949
+ **MicroscopyMatching** - A general-purpose microscopy image analysis toolkit based on Stable Diffusion
950
+
951
+ ### 📒 Note:
952
+
953
+ This project is currently available with usage limits for research trial use and feedback collection. We plan to release a free public version in the future. We are actively improving the toolkit and greatly appreciate your feedback!
954
+
955
  """
956
  )
957
 
 
1676
  outputs=[feedback_status, feedback_status]
1677
  )
1678
 
1679
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1680
 
1681
  if __name__ == "__main__":
1682
  demo.queue().launch(
1683
  server_name="0.0.0.0",
1684
+ server_port=7861,
1685
  share=False,
1686
  ssr_mode=False,
1687
  show_error=True,
counting.py CHANGED
@@ -153,9 +153,9 @@ class CountingModule(pl.LightningModule):
153
  loca_feature_bf_regression = loca_out["feature_bf_regression"]
154
  adapted_emb = self.counting_adapter.adapter(loca_feature_bf_regression, boxes) # shape [1, 768]
155
  if task_loc_idx.shape[0] == 0:
156
- encoder_hidden_states[0,2,:] = adapted_emb.squeeze() # 放在task prompt下一位
157
  else:
158
- encoder_hidden_states[0,task_loc_idx[0, 1]+1,:] = adapted_emb.squeeze() # 放在task prompt下一位
159
 
160
  # Predict the noise residual
161
  noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states)
@@ -174,7 +174,7 @@ class CountingModule(pl.LightningModule):
174
 
175
  # only use 64x64 self-attention
176
  self_attn_aggregate = attn_utils.aggregate_attention( # [res, res, 4096]
177
- prompts=[self.config.prompt], # 这里要改么
178
  attention_store=self.controller,
179
  res=64,
180
  from_where=("up", "down"),
@@ -182,7 +182,7 @@ class CountingModule(pl.LightningModule):
182
  select=0
183
  )
184
  self_attn_aggregate32 = attn_utils.aggregate_attention( # [res, res, 4096]
185
- prompts=[self.config.prompt], # 这里要改么
186
  attention_store=self.controller,
187
  res=32,
188
  from_where=("up", "down"),
@@ -190,7 +190,7 @@ class CountingModule(pl.LightningModule):
190
  select=0
191
  )
192
  self_attn_aggregate16 = attn_utils.aggregate_attention( # [res, res, 4096]
193
- prompts=[self.config.prompt], # 这里要改么
194
  attention_store=self.controller,
195
  res=16,
196
  from_where=("up", "down"),
@@ -201,7 +201,7 @@ class CountingModule(pl.LightningModule):
201
  # cross attention
202
  for res in [32, 16]:
203
  attn_aggregate = attn_utils.aggregate_attention( # [res, res, 77]
204
- prompts=[self.config.prompt], # 这里要改么
205
  attention_store=self.controller,
206
  res=res,
207
  from_where=("up", "down"),
@@ -212,7 +212,7 @@ class CountingModule(pl.LightningModule):
212
  task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0) # [1, 1, res, res]
213
  attention_maps.append(task_attn_)
214
  if self.use_box:
215
- exemplar_attns = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0) # 取exemplar的attn
216
  exemplar_attention_maps.append(exemplar_attns)
217
  else:
218
  exemplar_attns1 = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0)
@@ -266,10 +266,6 @@ class CountingModule(pl.LightningModule):
266
  attn_stack = torch.cat(attn_stack, dim=1)
267
 
268
  if not self.use_box:
269
-
270
- # cross_self_exe_attn_np = cross_self_exe_attn.detach().squeeze().cpu().numpy()
271
- # boxes = gen_dummy_boxes(cross_self_exe_attn_np, max_boxes=1)
272
- # boxes = boxes.to(self.device)
273
 
274
  loca_out = self.loca_model.forward_before_reg(input_image, boxes)
275
  loca_feature_bf_regression = loca_out["feature_bf_regression"]
 
153
  loca_feature_bf_regression = loca_out["feature_bf_regression"]
154
  adapted_emb = self.counting_adapter.adapter(loca_feature_bf_regression, boxes) # shape [1, 768]
155
  if task_loc_idx.shape[0] == 0:
156
+ encoder_hidden_states[0,2,:] = adapted_emb.squeeze()
157
  else:
158
+ encoder_hidden_states[0,task_loc_idx[0, 1]+1,:] = adapted_emb.squeeze()
159
 
160
  # Predict the noise residual
161
  noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states)
 
174
 
175
  # only use 64x64 self-attention
176
  self_attn_aggregate = attn_utils.aggregate_attention( # [res, res, 4096]
177
+ prompts=[self.config.prompt],
178
  attention_store=self.controller,
179
  res=64,
180
  from_where=("up", "down"),
 
182
  select=0
183
  )
184
  self_attn_aggregate32 = attn_utils.aggregate_attention( # [res, res, 4096]
185
+ prompts=[self.config.prompt],
186
  attention_store=self.controller,
187
  res=32,
188
  from_where=("up", "down"),
 
190
  select=0
191
  )
192
  self_attn_aggregate16 = attn_utils.aggregate_attention( # [res, res, 4096]
193
+ prompts=[self.config.prompt],
194
  attention_store=self.controller,
195
  res=16,
196
  from_where=("up", "down"),
 
201
  # cross attention
202
  for res in [32, 16]:
203
  attn_aggregate = attn_utils.aggregate_attention( # [res, res, 77]
204
+ prompts=[self.config.prompt],
205
  attention_store=self.controller,
206
  res=res,
207
  from_where=("up", "down"),
 
212
  task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0) # [1, 1, res, res]
213
  attention_maps.append(task_attn_)
214
  if self.use_box:
215
+ exemplar_attns = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0)
216
  exemplar_attention_maps.append(exemplar_attns)
217
  else:
218
  exemplar_attns1 = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0)
 
266
  attn_stack = torch.cat(attn_stack, dim=1)
267
 
268
  if not self.use_box:
 
 
 
 
269
 
270
  loca_out = self.loca_model.forward_before_reg(input_image, boxes)
271
  loca_feature_bf_regression = loca_out["feature_bf_regression"]
models/seg_post_model/models.py CHANGED
@@ -16,7 +16,7 @@ import logging
16
  models_logger = logging.getLogger(__name__)
17
 
18
  from . import transforms, dynamics, utils
19
- from .vit_sam import Transformer
20
  from .core import assign_device, run_net
21
 
22
  # _MODEL_DIR_ENV = os.environ.get("CELLPOSE_LOCAL_MODELS_PATH")
 
16
  models_logger = logging.getLogger(__name__)
17
 
18
  from . import transforms, dynamics, utils
19
+ from .vit import Transformer
20
  from .core import assign_device, run_net
21
 
22
  # _MODEL_DIR_ENV = os.environ.get("CELLPOSE_LOCAL_MODELS_PATH")
models/seg_post_model/{vit_sam.py → vit.py} RENAMED
File without changes
models/tra_post_model/data.py CHANGED
@@ -1,5 +1,6 @@
1
  """Regionprops features and its augmentations.
2
  WindowedRegionFeatures (WRFeatures) is a class that holds regionprops features for a windowed track region.
 
3
  """
4
 
5
  import itertools
 
1
  """Regionprops features and its augmentations.
2
  WindowedRegionFeatures (WRFeatures) is a class that holds regionprops features for a windowed track region.
3
+ Modified from Trackastra (https://github.com/weigertlab/trackastra)
4
  """
5
 
6
  import itertools
models/tra_post_model/model.py CHANGED
@@ -599,47 +599,14 @@ class TrackingTransformer(torch.nn.Module):
599
 
600
  @classmethod
601
  def from_folder(
602
- cls, folder, map_location=None, args=None, checkpoint_path: str = "model.pt"
603
  ):
604
  folder = Path(folder)
605
 
606
  config = yaml.load(open(folder / "config.yaml"), Loader=yaml.FullLoader)
607
- if args:
608
- args = vars(args)
609
- for k, v in config.items():
610
- errors = []
611
- if k in args:
612
- if config[k] != args[k]:
613
- errors.append(
614
- f"Loaded model config {k}={config[k]}, but current argument"
615
- f" {k}={args[k]}."
616
- )
617
- if errors:
618
- raise ValueError("\n".join(errors))
619
 
620
  model = cls(**config)
621
 
622
- # try:
623
- # # Try to load from lightning checkpoint first
624
- # v_folder = sorted((folder / "tb").glob("version_*"))[version]
625
- # checkpoint = sorted((v_folder / "checkpoints").glob("*epoch*.ckpt"))[0]
626
- # pl_state_dict = torch.load(checkpoint, map_location=map_location)[
627
- # "state_dict"
628
- # ]
629
- # state_dict = OrderedDict()
630
-
631
- # # Hack
632
- # for k, v in pl_state_dict.items():
633
- # if k.startswith("model."):
634
- # state_dict[k[6:]] = v
635
- # else:
636
- # raise ValueError(f"Unexpected key {k} in state_dict")
637
-
638
- # model.load_state_dict(state_dict)
639
- # logger.info(f"Loaded model from {checkpoint}")
640
- # except:
641
- # # Default: Load manually saved model (legacy)
642
-
643
  fpath = folder / checkpoint_path
644
  logger.info(f"Loading model state from {fpath}")
645
 
@@ -656,24 +623,12 @@ class TrackingTransformer(torch.nn.Module):
656
 
657
  @classmethod
658
  def from_cfg(
659
- cls, cfg_path, args=None
660
  ):
661
 
662
  cfg_path = Path(cfg_path)
663
 
664
  config = yaml.load(open(cfg_path), Loader=yaml.FullLoader)
665
- if args:
666
- args = vars(args)
667
- for k, v in config.items():
668
- errors = []
669
- if k in args:
670
- if config[k] != args[k]:
671
- errors.append(
672
- f"Loaded model config {k}={config[k]}, but current argument"
673
- f" {k}={args[k]}."
674
- )
675
- if errors:
676
- raise ValueError("\n".join(errors))
677
 
678
  model = cls(**config)
679
 
 
599
 
600
  @classmethod
601
  def from_folder(
602
+ cls, folder, map_location=None, checkpoint_path: str = "model.pt"
603
  ):
604
  folder = Path(folder)
605
 
606
  config = yaml.load(open(folder / "config.yaml"), Loader=yaml.FullLoader)
 
 
 
 
 
 
 
 
 
 
 
 
607
 
608
  model = cls(**config)
609
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
610
  fpath = folder / checkpoint_path
611
  logger.info(f"Loading model state from {fpath}")
612
 
 
623
 
624
  @classmethod
625
  def from_cfg(
626
+ cls, cfg_path
627
  ):
628
 
629
  cfg_path = Path(cfg_path)
630
 
631
  config = yaml.load(open(cfg_path), Loader=yaml.FullLoader)
 
 
 
 
 
 
 
 
 
 
 
 
632
 
633
  model = cls(**config)
634
 
models/tra_post_model/tracking/__init__.py CHANGED
@@ -1,5 +1,3 @@
1
- # ruff: noqa: F401
2
-
3
  from .track_graph import TrackGraph
4
  from .tracking import (
5
  build_graph,
 
 
 
1
  from .track_graph import TrackGraph
2
  from .tracking import (
3
  build_graph,
models/tra_post_model/tracking/ilp.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import logging
2
  import time
3
  from types import SimpleNamespace
 
1
+ # Modified from Trackastra (https://github.com/weigertlab/trackastra)
2
+
3
  import logging
4
  import time
5
  from types import SimpleNamespace
models/tra_post_model/tracking/tracking.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import logging
2
  from itertools import chain
3
 
 
1
+ # Modified from Trackastra (https://github.com/weigertlab/trackastra)
2
+
3
  import logging
4
  from itertools import chain
5
 
models/tra_post_model/tracking/utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import logging
2
  from collections import deque
3
  from pathlib import Path
 
1
+ # Modified from Trackastra (https://github.com/weigertlab/trackastra)
2
+
3
  import logging
4
  from collections import deque
5
  from pathlib import Path
models/tra_post_model/utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import logging
2
 
3
  import dask.array as da
@@ -41,7 +43,6 @@ def blockwise_sum(
41
  return B
42
 
43
 
44
- # TODO allow for batch dimension. Should be faster than looping
45
  def blockwise_causal_norm(
46
  A: torch.Tensor,
47
  timepoints: torch.Tensor,
@@ -70,7 +71,7 @@ def blockwise_causal_norm(
70
  if mode in ("softmax", "quiet_softmax"):
71
  # Subtract max for numerical stability
72
  # https://stats.stackexchange.com/questions/338285/how-does-the-subtraction-of-the-logit-maximum-improve-learning
73
- # TODO test without this subtraction
74
 
75
  if mask_invalid is not None:
76
  assert mask_invalid.shape == A.shape
 
1
+ # Modified from Trackastra (https://github.com/weigertlab/trackastra)
2
+
3
  import logging
4
 
5
  import dask.array as da
 
43
  return B
44
 
45
 
 
46
  def blockwise_causal_norm(
47
  A: torch.Tensor,
48
  timepoints: torch.Tensor,
 
71
  if mode in ("softmax", "quiet_softmax"):
72
  # Subtract max for numerical stability
73
  # https://stats.stackexchange.com/questions/338285/how-does-the-subtraction-of-the-logit-maximum-improve-learning
74
+
75
 
76
  if mask_invalid is not None:
77
  assert mask_invalid.shape == A.shape
tracking_one.py CHANGED
@@ -1,16 +1,11 @@
1
  import os
2
- import pprint
3
  from typing import Any, List, Optional
4
- import argparse
5
  from huggingface_hub import hf_hub_download
6
- import pyrallis
7
  from pytorch_lightning.utilities.types import STEP_OUTPUT
8
  import torch
9
- import os
10
  from PIL import Image
11
  import numpy as np
12
  import tifffile
13
- import skimage.io as io
14
  from config import RunConfig
15
  from _utils import attn_utils_new as attn_utils
16
  from _utils.attn_utils_new import AttentionStore
@@ -18,7 +13,6 @@ from _utils.misc_helper import *
18
  import torch.nn.functional as F
19
  from tqdm import tqdm
20
  import torch.nn as nn
21
- import matplotlib.pyplot as plt
22
  import cv2
23
  import warnings
24
  warnings.filterwarnings("ignore", category=UserWarning)
@@ -33,7 +27,6 @@ from models.tra_post_model.utils import (
33
  )
34
  from models.tra_post_model.data import build_windows_sd, get_features
35
  from models.tra_post_model.tracking import TrackGraph, build_graph, track_greedy
36
- from _utils.track_args import parse_train_args as get_track_args
37
  import torchvision.transforms as T
38
  from pathlib import Path
39
  import dask.array as da
@@ -41,7 +34,6 @@ from typing import Dict, List, Optional, Union, Literal
41
  from scipy.sparse import SparseEfficiencyWarning, csr_array
42
  import tracemalloc
43
  import gc
44
- # from memory_profiler import profile
45
  from _utils.load_track_data import load_track_images
46
 
47
  SCALE = 1
@@ -82,15 +74,8 @@ class TrackingModule(pl.LightningModule):
82
 
83
  # load loca model
84
  self.loca_model = build_loca_model()
85
- # weights = torch.load("ckpt/loca_few_shot.pt")["model"]
86
- # weights = {k.replace("module","") : v for k, v in weights.items()}
87
- # self.loca_model.load_state_dict(weights, strict=False)
88
- # del weights
89
 
90
  self.counting_adapter = Counting(scale_factor=SCALE)
91
- # if os.path.isfile(self.args.adapter_weight):
92
- # adapter_weight = torch.load(self.args.adapter_weight,map_location=torch.device('cpu'))
93
- # self.counting_adapter.load_state_dict(adapter_weight, strict=False)
94
 
95
  ### load stable diffusion and its controller
96
  self.stable = load_stable_diffusion_model(config=self.config)
@@ -110,7 +95,6 @@ class TrackingModule(pl.LightningModule):
110
  " `placeholder_token` that is not already in the tokenizer."
111
  )
112
  try:
113
- # task_embed_from_pretrain = torch.load("pretrained/task_embed.pth")
114
  task_embed_from_pretrain = hf_hub_download(
115
  repo_id="phoebe777777/111",
116
  filename="task_embed.pth",
@@ -144,30 +128,17 @@ class TrackingModule(pl.LightningModule):
144
  self.placeholder_token_id = placeholder_token_id
145
 
146
  fpath = Path("_utils/config.yaml")
147
- args_ = get_track_args()
148
 
149
  model = TrackingTransformer.from_cfg(
150
  cfg_path=fpath,
151
- args=args_,
152
  )
153
- # model = TrackingTransformer.from_folder(
154
- # Path(*fpath.parts[:-1]),
155
- # args=args_,
156
- # checkpoint_path=Path(*fpath.parts[-1:]),
157
- # )
158
-
159
 
160
  self.track_model = model
161
- self.track_args = args_
162
 
163
 
164
  def move_to_device(self, device):
165
  self.stable.to(device)
166
- # if self.loca_model is not None and self.counting_adapter is not None:
167
- # self.loca_model.to(device)
168
- # self.counting_adapter.to(device)
169
  self.counting_adapter.to(device)
170
- # self.dino.to(device)
171
  self.loca_model.to(device)
172
  self.track_model.to(device)
173
 
@@ -221,9 +192,9 @@ class TrackingModule(pl.LightningModule):
221
  adapted_emb = self.counting_adapter.adapter(loca_feature_bf_regression, boxes) # shape [1, 768]
222
 
223
  if task_loc_idx.shape[0] == 0:
224
- encoder_hidden_states[0,2,:] = adapted_emb.squeeze() # 放在task prompt下一位
225
  else:
226
- encoder_hidden_states[:,task_loc_idx[0, 1]+1,:] = adapted_emb.squeeze() # 放在task prompt下一位
227
 
228
  # Predict the noise residual
229
  noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states)
@@ -242,7 +213,7 @@ class TrackingModule(pl.LightningModule):
242
 
243
  # only use 64x64 self-attention
244
  self_attn_aggregate = attn_utils.aggregate_attention( # [res, res, 4096]
245
- prompts=[self.config.prompt for i in range(bsz)], # 这里要改么
246
  attention_store=self.controller,
247
  res=64,
248
  from_where=("up", "down"),
@@ -250,7 +221,7 @@ class TrackingModule(pl.LightningModule):
250
  select=0
251
  )
252
  self_attn_aggregate32 = attn_utils.aggregate_attention( # [res, res, 4096]
253
- prompts=[self.config.prompt for i in range(bsz)], # 这里要改么
254
  attention_store=self.controller,
255
  res=32,
256
  from_where=("up", "down"),
@@ -258,7 +229,7 @@ class TrackingModule(pl.LightningModule):
258
  select=0
259
  )
260
  self_attn_aggregate16 = attn_utils.aggregate_attention( # [res, res, 4096]
261
- prompts=[self.config.prompt for i in range(bsz)], # 这里要改么
262
  attention_store=self.controller,
263
  res=16,
264
  from_where=("up", "down"),
@@ -269,7 +240,7 @@ class TrackingModule(pl.LightningModule):
269
  # cross attention
270
  for res in [32, 16]:
271
  attn_aggregate = attn_utils.aggregate_attention( # [res, res, 77]
272
- prompts=[self.config.prompt for i in range(bsz)], # 这里要改么
273
  attention_store=self.controller,
274
  res=res,
275
  from_where=("up", "down"),
@@ -279,7 +250,7 @@ class TrackingModule(pl.LightningModule):
279
 
280
  task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0) # [1, 1, res, res]
281
  attention_maps.append(task_attn_)
282
- exemplar_attns = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0) # 取exemplar的attn
283
  exemplar_attention_maps.append(exemplar_attns)
284
 
285
 
@@ -306,7 +277,7 @@ class TrackingModule(pl.LightningModule):
306
  attn_stack = torch.cat(attn_stack, dim=1)
307
 
308
 
309
- attn_after_new_regressor, loss = self.counting_adapter.regressor(input_image, attn_stack, feature_list, mask.cpu().numpy(), training=False) # 直接用自己的
310
 
311
  return {
312
  "attn_after_new_regressor":attn_after_new_regressor,
@@ -364,9 +335,9 @@ class TrackingModule(pl.LightningModule):
364
  adapted_emb = self.adapt_emb.to(self.device)
365
  task_loc_idx = torch.nonzero(input_ids == self.placeholder_token_id)
366
  if task_loc_idx.shape[0] == 0:
367
- encoder_hidden_states[0,5,:] = adapted_emb.squeeze() # 放在task prompt下一位
368
  else:
369
- encoder_hidden_states[:,task_loc_idx[0, 1]+4,:] = adapted_emb.squeeze() # 放在task prompt下一位
370
 
371
  # Predict the noise residual
372
  noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states)
@@ -386,7 +357,7 @@ class TrackingModule(pl.LightningModule):
386
 
387
  # only use 64x64 self-attention
388
  self_attn_aggregate = attn_utils.aggregate_attention( # [res, res, 4096]
389
- prompts=[self.config.prompt for i in range(bsz)], # 这里要改么
390
  attention_store=self.controller,
391
  res=64,
392
  from_where=("up", "down"),
@@ -397,7 +368,7 @@ class TrackingModule(pl.LightningModule):
397
  # cross attention
398
  for res in [32, 16]:
399
  attn_aggregate = attn_utils.aggregate_attention( # [res, res, 77]
400
- prompts=[self.config.prompt for i in range(bsz)], # 这里要改么
401
  attention_store=self.controller,
402
  res=res,
403
  from_where=("up", "down"),
@@ -408,13 +379,13 @@ class TrackingModule(pl.LightningModule):
408
  task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0) # [1, 1, res, res]
409
  attention_maps.append(task_attn_)
410
  # if self.boxes is not None and not self.training:
411
- exemplar_attns1 = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0) # 取exemplar的attn
412
  exemplar_attention_maps1.append(exemplar_attns1)
413
- exemplar_attns2 = attn_aggregate[:, :, 3].unsqueeze(0).unsqueeze(0) # 取exemplar的attn
414
  exemplar_attention_maps2.append(exemplar_attns2)
415
- exemplar_attns3 = attn_aggregate[:, :, 4].unsqueeze(0).unsqueeze(0) # 取exemplar的attn
416
  exemplar_attention_maps3.append(exemplar_attns3)
417
- exemplar_attns4 = attn_aggregate[:, :, 5].unsqueeze(0).unsqueeze(0) # 取exemplar的attn
418
  exemplar_attention_maps4.append(exemplar_attns4)
419
 
420
 
@@ -540,8 +511,7 @@ class TrackingModule(pl.LightningModule):
540
 
541
  for n in range(n_forward):
542
  len_ = min(74, n_instance - n * 74)
543
- encoder_hidden_states[:,(task_loc_idx[0, 1]+1):(task_loc_idx[0, 1]+1+len_),:] = adapted_emb[n*74:n*74+len_].squeeze() # 放在task prompt下一位
544
- # encoder_hidden_states: # [bsz, 77, 768], 其中第1位是task prompt的embedding, 第二位开始可以是object prompt的embedding, 最后一位应该保留原始embedding
545
 
546
 
547
  # Predict the noise residual
@@ -556,7 +526,7 @@ class TrackingModule(pl.LightningModule):
556
  # cross attention
557
  for res in [32, 16]:
558
  attn_aggregate = attn_utils.aggregate_attention( # [res, res, 77]
559
- prompts=[self.config.prompt for i in range(bsz)], # 这里要改么
560
  attention_store=self.controller,
561
  res=res,
562
  from_where=("up", "down"),
@@ -567,7 +537,7 @@ class TrackingModule(pl.LightningModule):
567
  task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0) # [1, 1, res, res]
568
  attention_maps.append(task_attn_)
569
  try:
570
- exemplar_attns = attn_aggregate[:, :, (task_loc_idx[0, 1]+1):(task_loc_idx[0, 1]+1+len_)].unsqueeze(0) # 取exemplar的attn
571
  except:
572
  print(n_instance, len_)
573
  exemplar_attns = torch.permute(exemplar_attns, (0, 3, 1, 2)) # [1, len_, res, res]
@@ -728,11 +698,6 @@ class TrackingModule(pl.LightningModule):
728
 
729
  A = self.track_model.normalize_output(A, timepoints, coords)
730
 
731
- # # Spatially far entries should not influence the causal normalization
732
- # dist = torch.cdist(coords[0, :, 1:], coords[0, :, 1:])
733
- # invalid = dist > model.config["spatial_pos_cutoff"]
734
- # A[invalid] = -torch.inf
735
-
736
  A = A.squeeze(0).detach().cpu().numpy()
737
 
738
  del feats, coords, timepoints, batch
@@ -1020,30 +985,3 @@ class TrackingModule(pl.LightningModule):
1020
  track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs)
1021
 
1022
  return track_graph, masks
1023
-
1024
-
1025
-
1026
- # def inference(data_path, box=None):
1027
- # if box is not None:
1028
- # use_box = True
1029
- # else:
1030
- # use_box = False
1031
-
1032
- # model = TrackingModule(use_box=use_box)
1033
- # load_msg = model.load_state_dict(torch.load("pretrained/microscopy_matching_tra.pth"), strict=True)
1034
-
1035
- # model.move_to_device(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
1036
-
1037
-
1038
- # track_graph, masks = model.track(file_dir=data_path, dataname="inference_sequence")
1039
-
1040
- # if not os.path.exists(f"tracked_ours_seg_pred3/"):
1041
- # os.makedirs(f"tracked_ours_seg_pred3/")
1042
- # ctc_tracks, masks_tracked = graph_to_ctc(
1043
- # track_graph,
1044
- # masks,
1045
- # outdir=f"tracked_ours_seg_pred3/",
1046
- # )
1047
-
1048
- # if __name__ == "__main__":
1049
- # inference(data_path="example_imgs/2D+Time/Fluo-N2DL-HeLa/train/Fluo-N2DL-HeLa/02")
 
1
  import os
 
2
  from typing import Any, List, Optional
 
3
  from huggingface_hub import hf_hub_download
 
4
  from pytorch_lightning.utilities.types import STEP_OUTPUT
5
  import torch
 
6
  from PIL import Image
7
  import numpy as np
8
  import tifffile
 
9
  from config import RunConfig
10
  from _utils import attn_utils_new as attn_utils
11
  from _utils.attn_utils_new import AttentionStore
 
13
  import torch.nn.functional as F
14
  from tqdm import tqdm
15
  import torch.nn as nn
 
16
  import cv2
17
  import warnings
18
  warnings.filterwarnings("ignore", category=UserWarning)
 
27
  )
28
  from models.tra_post_model.data import build_windows_sd, get_features
29
  from models.tra_post_model.tracking import TrackGraph, build_graph, track_greedy
 
30
  import torchvision.transforms as T
31
  from pathlib import Path
32
  import dask.array as da
 
34
  from scipy.sparse import SparseEfficiencyWarning, csr_array
35
  import tracemalloc
36
  import gc
 
37
  from _utils.load_track_data import load_track_images
38
 
39
  SCALE = 1
 
74
 
75
  # load loca model
76
  self.loca_model = build_loca_model()
 
 
 
 
77
 
78
  self.counting_adapter = Counting(scale_factor=SCALE)
 
 
 
79
 
80
  ### load stable diffusion and its controller
81
  self.stable = load_stable_diffusion_model(config=self.config)
 
95
  " `placeholder_token` that is not already in the tokenizer."
96
  )
97
  try:
 
98
  task_embed_from_pretrain = hf_hub_download(
99
  repo_id="phoebe777777/111",
100
  filename="task_embed.pth",
 
128
  self.placeholder_token_id = placeholder_token_id
129
 
130
  fpath = Path("_utils/config.yaml")
 
131
 
132
  model = TrackingTransformer.from_cfg(
133
  cfg_path=fpath,
 
134
  )
 
 
 
 
 
 
135
 
136
  self.track_model = model
 
137
 
138
 
139
  def move_to_device(self, device):
140
  self.stable.to(device)
 
 
 
141
  self.counting_adapter.to(device)
 
142
  self.loca_model.to(device)
143
  self.track_model.to(device)
144
 
 
192
  adapted_emb = self.counting_adapter.adapter(loca_feature_bf_regression, boxes) # shape [1, 768]
193
 
194
  if task_loc_idx.shape[0] == 0:
195
+ encoder_hidden_states[0,2,:] = adapted_emb.squeeze()
196
  else:
197
+ encoder_hidden_states[:,task_loc_idx[0, 1]+1,:] = adapted_emb.squeeze()
198
 
199
  # Predict the noise residual
200
  noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states)
 
213
 
214
  # only use 64x64 self-attention
215
  self_attn_aggregate = attn_utils.aggregate_attention( # [res, res, 4096]
216
+ prompts=[self.config.prompt for i in range(bsz)],
217
  attention_store=self.controller,
218
  res=64,
219
  from_where=("up", "down"),
 
221
  select=0
222
  )
223
  self_attn_aggregate32 = attn_utils.aggregate_attention( # [res, res, 4096]
224
+ prompts=[self.config.prompt for i in range(bsz)],
225
  attention_store=self.controller,
226
  res=32,
227
  from_where=("up", "down"),
 
229
  select=0
230
  )
231
  self_attn_aggregate16 = attn_utils.aggregate_attention( # [res, res, 4096]
232
+ prompts=[self.config.prompt for i in range(bsz)],
233
  attention_store=self.controller,
234
  res=16,
235
  from_where=("up", "down"),
 
240
  # cross attention
241
  for res in [32, 16]:
242
  attn_aggregate = attn_utils.aggregate_attention( # [res, res, 77]
243
+ prompts=[self.config.prompt for i in range(bsz)],
244
  attention_store=self.controller,
245
  res=res,
246
  from_where=("up", "down"),
 
250
 
251
  task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0) # [1, 1, res, res]
252
  attention_maps.append(task_attn_)
253
+ exemplar_attns = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0)
254
  exemplar_attention_maps.append(exemplar_attns)
255
 
256
 
 
277
  attn_stack = torch.cat(attn_stack, dim=1)
278
 
279
 
280
+ attn_after_new_regressor, loss = self.counting_adapter.regressor(input_image, attn_stack, feature_list, mask.cpu().numpy(), training=False)
281
 
282
  return {
283
  "attn_after_new_regressor":attn_after_new_regressor,
 
335
  adapted_emb = self.adapt_emb.to(self.device)
336
  task_loc_idx = torch.nonzero(input_ids == self.placeholder_token_id)
337
  if task_loc_idx.shape[0] == 0:
338
+ encoder_hidden_states[0,5,:] = adapted_emb.squeeze()
339
  else:
340
+ encoder_hidden_states[:,task_loc_idx[0, 1]+4,:] = adapted_emb.squeeze()
341
 
342
  # Predict the noise residual
343
  noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states)
 
357
 
358
  # only use 64x64 self-attention
359
  self_attn_aggregate = attn_utils.aggregate_attention( # [res, res, 4096]
360
+ prompts=[self.config.prompt for i in range(bsz)],
361
  attention_store=self.controller,
362
  res=64,
363
  from_where=("up", "down"),
 
368
  # cross attention
369
  for res in [32, 16]:
370
  attn_aggregate = attn_utils.aggregate_attention( # [res, res, 77]
371
+ prompts=[self.config.prompt for i in range(bsz)],
372
  attention_store=self.controller,
373
  res=res,
374
  from_where=("up", "down"),
 
379
  task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0) # [1, 1, res, res]
380
  attention_maps.append(task_attn_)
381
  # if self.boxes is not None and not self.training:
382
+ exemplar_attns1 = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0)
383
  exemplar_attention_maps1.append(exemplar_attns1)
384
+ exemplar_attns2 = attn_aggregate[:, :, 3].unsqueeze(0).unsqueeze(0)
385
  exemplar_attention_maps2.append(exemplar_attns2)
386
+ exemplar_attns3 = attn_aggregate[:, :, 4].unsqueeze(0).unsqueeze(0)
387
  exemplar_attention_maps3.append(exemplar_attns3)
388
+ exemplar_attns4 = attn_aggregate[:, :, 5].unsqueeze(0).unsqueeze(0)
389
  exemplar_attention_maps4.append(exemplar_attns4)
390
 
391
 
 
511
 
512
  for n in range(n_forward):
513
  len_ = min(74, n_instance - n * 74)
514
+ encoder_hidden_states[:,(task_loc_idx[0, 1]+1):(task_loc_idx[0, 1]+1+len_),:] = adapted_emb[n*74:n*74+len_].squeeze()
 
515
 
516
 
517
  # Predict the noise residual
 
526
  # cross attention
527
  for res in [32, 16]:
528
  attn_aggregate = attn_utils.aggregate_attention( # [res, res, 77]
529
+ prompts=[self.config.prompt for i in range(bsz)],
530
  attention_store=self.controller,
531
  res=res,
532
  from_where=("up", "down"),
 
537
  task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0) # [1, 1, res, res]
538
  attention_maps.append(task_attn_)
539
  try:
540
+ exemplar_attns = attn_aggregate[:, :, (task_loc_idx[0, 1]+1):(task_loc_idx[0, 1]+1+len_)].unsqueeze(0)
541
  except:
542
  print(n_instance, len_)
543
  exemplar_attns = torch.permute(exemplar_attns, (0, 3, 1, 2)) # [1, len_, res, res]
 
698
 
699
  A = self.track_model.normalize_output(A, timepoints, coords)
700
 
 
 
 
 
 
701
  A = A.squeeze(0).detach().cpu().numpy()
702
 
703
  del feats, coords, timepoints, batch
 
985
  track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs)
986
 
987
  return track_graph, masks