huzey committed
Commit 905cc0d · verified · 1 Parent(s): 3f545d8

Upload folder using huggingface_hub

ckpts/part1_voxel_indices.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:81c686f552aba9902cd99d841502a2fd29cd9f2c325d9f47c3db1d54eb4e8862
+ size 3690757
ckpts/subj01_part1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:832b67cebae7d34ce51720ce63e5aa69e5a1dc4d20c7578e2481410db0d4d014
+ size 1012959757
ckpts/subj01_part2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f72e8ead92ca7c4f6ec8f10df9028000e71808992e71d433fe8326e5289a95f6
+ size 1228352781
ckpts/subj02_part1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:778f90969556033cf6f8b41c10055f7ed004dc72d8c1fa583feb21c0f159af5b
+ size 1010962957
ckpts/subj02_part2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f9c114b7aa9ff2f13ab598d61cd7cd585c297a8a44a2c91ac9d5328474fd28f5
+ size 1228352781
ckpts/subj03_part1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:172ae65d4fb580b9a1d571c5d8d32cbeca5d1db81b241376b37389888dbf33e6
+ size 1009368973
ckpts/subj03_part2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5ee8ea85a3454966f818e5d2b50ae44eff84f25c894cfb030bea86650b53cc94
+ size 1228352781
ckpts/subj04_part1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3230feb7ebbaf7f917c71b88a8633142002f9dd9ac0bf28593c646c3fb06edad
+ size 1014516813
ckpts/subj04_part2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:97065ecede6d1d0d78282ce33b48eebb6d37add7c132960956c8fbbd25f22480
+ size 1228352781
ckpts/subj05_part1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6539e7b30d85c93aab16c26d608db31f47f82a4b4863cfccc4f23e077918e0f2
+ size 1011562637
ckpts/subj05_part2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ef09dc8f97bd792007f376818fdabbca6969a982b07fc2f4e2dc43f6864b2e43
+ size 1228352781
ckpts/subj06_part1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f52293b73fbac7581a50d15a7680035efda909d64a0c89e51e912603ddb8a05a
+ size 1010378061
ckpts/subj06_part2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1e84b9dc5a8d60db70303726dcb187705aecd8c312efec279d2bb57de71dbcc5
+ size 1228352781
ckpts/subj07_part1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b37b0778c0e7a8175c9f8b89722496048bd6ef1a824f5e12c880211177ae33b5
+ size 1015641869
ckpts/subj07_part2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ea15a4dd91f4d755eccc7cdf83ea72ffda61eb70775922a3548ad37dc07bfe1e
+ size 1228352781
ckpts/subj08_part1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bccd9e8c68b47ed41f88f3491c5ef2618fd92c860c3ab92c612c256e1c1d46ec
+ size 1007108749
ckpts/subj08_part2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:508dd0f7ff9e1e4098a6550e89fefe842097bd9bd1510d3b4514300bb008119e
+ size 1228352781
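
The ckpts/ entries above are Git LFS pointers; the actual weights live in LFS storage. A minimal sketch for fetching them with huggingface_hub (the library named in the commit message); the repo_id below is a placeholder for this repository's id:

from huggingface_hub import snapshot_download

# Download config.yaml and the ckpts/ folder locally; "user/repo" is a placeholder repo_id.
local_dir = snapshot_download(repo_id="user/repo", allow_patterns=["config.yaml", "ckpts/*"])
print(local_dir)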
config.yaml CHANGED
@@ -9,10 +9,10 @@ DATAMODULE:
  PIN_MEMORY: true
  DATASET:
  CACHE_DIR: /data/cache
- DARK_POSTFIX:
+ DARK_POSTFIX: xvdb
  FILTER_BY_SESSION:
  - -1
- FMRI_SPACE: full_fsaverage
+ FMRI_SPACE: fsaverage
  IMAGE_RESOLUTION:
  - 224
  - 224
@@ -22,13 +22,6 @@ DATASET:
  ROOT: /data/ALG23
  SUBJECT_LIST:
  - subj01
- - subj02
- - subj03
- - subj04
- - subj05
- - subj06
- - subj07
- - subj08
  DESCRIPTION: for alex
  EXPERIMENTAL:
  ANOTHER_SPLIT: false
@@ -53,8 +46,8 @@ EXPERIMENTAL:
  USE_FTR_BEHV: false
  LOSS:
  DARK:
- MAX_EPOCH: 150
- USE: false
+ MAX_EPOCH: 50
+ USE: true
  NAME: SmoothL1Loss
  SMOOTH_L1_BETA: 0.01
  SYNC:
@@ -67,7 +60,7 @@ LOSS:
  SKIP_EPOCHS: 20
  STAGE: VAL
  UPDATE_RULE: raw
- USE: false
+ USE: true
  MODEL:
  BACKBONE:
  ADAPTIVE_LN:
@@ -150,13 +143,13 @@ MODEL_SOUP:
  RECIPE: greedy
  USE: true
  OPTIMIZER:
- LR: 0.0001
+ LR: 0.0003
  NAME: AdamW
  SCHEDULER:
  CYCLE_DECAY: 0.5
  CYCLE_LIMIT: 3
  K_DECAY: 1.5
- LR_MIN: 0.0001
+ LR_MIN: 0.0003
  LR_MIN_WARMUP: 0.0001
  T_INITIAL: 1
  T_MULT: 1.0
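
For reference, the shipped config can be loaded the same way example.py loads it; a minimal sketch (the path matches the packed layout used below, and the printed key is one the model code reads):

from config_utils import load_from_yaml

# Load the packed inference config into the repo's AutoConfig-style object.
cfg = load_from_yaml("/workspace/model_packed2/config.yaml")
print(cfg.MODEL.BACKBONE.LAYERS)  # backbone layers consumed by BrainEncodingModel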
drafts/cpfiles.py ADDED
@@ -0,0 +1,22 @@
+ # %%
+ import os
+ # %%
+ source_dir = "/nfscc/alg23/xalex_distill2/wb/"
+ path_contains = "DARK.MAX_EPOCH=90"
+ search_filename = "soup.pth"
+ # %%
+ files = []
+ import glob
+ for file in glob.glob(os.path.join(source_dir, f"*{path_contains}*", search_filename)):
+     files.append(file)
+ files = sorted(files)
+ print(files)
+ # %%
+ target_dir = "/workspace/model_packed2/ckpts/"
+
+ for i_subj, file in enumerate(files):
+     print(f"copy {file} to {target_dir}")
+     target_filename = f"subj{i_subj+1:02d}_part2.pth"
+     target_path = os.path.join(target_dir, target_filename)
+     os.system(f"cp {file} {target_path}")
+ # %%
drafts/cpvoxel.py ADDED
@@ -0,0 +1,15 @@
+ # %%
+ import numpy as np
+ import torch
+ import os
+ # %%
+ path = "/data/ALG23/subj01/data_mask/fsaverage/voxel_indices.npy"
+ part1_voxel_indice_dict = {}
+ for i in range(1, 9):
+     part1_voxel_indice_dict[f'subj0{i}'] = np.load(f"/data/ALG23/subj0{i}/data_mask/fsaverage/voxel_indices.npy")
+ torch.save(part1_voxel_indice_dict, "/nfscc/alg23/xalex_distill2/high/voxel_indices_dict.pth")
+ # %%
+ part1_voxel_indice_dict = torch.load("/nfscc/alg23/xalex_distill2/high/voxel_indices_dict.pth")
+ # %%
+ part1_voxel_indice_dict['subj01']
+ # %%
drafts/try_load.py ADDED
@@ -0,0 +1,10 @@
+ # %%
+ path = "/nfscc/alg23/xalex_distill2/high/t826c6_00016_DATASET.SUBJECT_LIST=subj01,LOSS.DARK.MAX_EPOCH=90,/soup.pth"
+ import torch
+ sd = torch.load(path, map_location='cpu')
+ print(sd.keys())
+ # %%
+ sd['coord_dict.subj01'].shape
+ # %%
+ sd['model.voxel_outs_weight.subj01.weight'].shape
+ # %%
example.py CHANGED
@@ -1,13 +1,22 @@
- from model import BrainEncodingModel
+ from model import _load_one_model, TowPartModel, BrainEncodingModel
  from config_utils import load_from_yaml
  import torch

- cfg = load_from_yaml("./config.yaml")
- model = BrainEncodingModel(cfg)
- sd_path = './ckpt.pth'
- sd = torch.load(sd_path)
- model.load_state_dict(sd)
- model.eval().cuda()
+ subject = 'subj01'
+ cfg_path = "/workspace/model_packed2/config.yaml"
+ model_path1 = f"/workspace/model_packed2/ckpts/{subject}_part1.pth"
+ model_path2 = f"/workspace/model_packed2/ckpts/{subject}_part2.pth"
+ # model1 is for vertices with high noise ceiling (nsdgeneral)
+ # model2 is for vertices from the rest of the brain
+ model1: BrainEncodingModel = _load_one_model(model_path1, subject, cfg_path)
+ model2: BrainEncodingModel = _load_one_model(model_path2, subject, cfg_path)
+ # voxel_indices is a list of indices of vertices with high noise ceiling (for model1)
+ voxel_indices_path = "/workspace/model_packed2/ckpts/part1_voxel_indices.pt"
+ voxel_indices = torch.load(voxel_indices_path)[subject]
+ model = TowPartModel(model1, model2, voxel_indices)
+
+ model = model.cuda().eval()
+

  x = torch.randn(1, 3, 224, 224)
  def transform_image(x):
@@ -18,10 +27,8 @@ def transform_image(x):
  x = transform_image(x)
  x = x.cuda()

-
- subject = 'subj01' # could be 1 of 8 subjects

  with torch.no_grad():
-     out = model(x, subject)
+     out = model(x)
  print(out.shape)
  # torch.Size([1, 327684])
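
The 327684-dimensional output covers the full fsaverage7 surface: 163,842 vertices per hemisphere, left hemisphere first, matching the lh/rh concatenation order in the former get_coords() in model.py. A sketch for splitting a prediction into hemispheres; a random tensor stands in for the output produced above:

import torch

out = torch.randn(1, 327684)             # stand-in for the prediction from the example above
n_per_hemi = 163842                       # fsaverage7 vertices per hemisphere
lh_pred, rh_pred = out[:, :n_per_hemi], out[:, n_per_hemi:]
print(lh_pred.shape, rh_pred.shape)       # torch.Size([1, 163842]) each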
model.py CHANGED
@@ -1,37 +1,26 @@
+ #%%
  from functools import partial
  import logging
- from torch import nn, Tensor
  from einops import rearrange, repeat
  from typing import Dict, Optional, Union

  import torch
  import torch.nn.functional as F
+ from torch import nn, Tensor
  from config import AutoConfig

  from backbone import (
-     SubjectTimeEmbed,
      build_backbone,
      AdaLNLoRADiNOv2ViT,
-     build_backbone_prev,
-     build_time_emd,
  )
  from blocks import (
-     PreviousFeatureMLPs,
-     SubjectPreviousFrameCompress,
-     build_class_token_mlp_prev,
      build_conv_blocks,
      build_class_token_mlp,
      DictConvBlocks,
      ClassTokenMLPs,
-     build_ftr_compress,
-     build_prev_compress,
-     build_prev_feat_mlp,
  )
- from behav_embed import build_behavior_embed, SubjectBehaviorEmbed
  from config_utils import load_from_yaml
  from topyneck import (
-     CoordsFreeWeights,
-     build_coords_free_weights,
      build_coords_mlp,
      CachedCoordsMLP,
      build_voxelouts_weight,
@@ -41,39 +30,16 @@ from topyneck import (

  import numpy as np

- def get_coords():
-     import nilearn
-     from nilearn import datasets, surface
-
-     fsaverage = nilearn.datasets.fetch_surf_fsaverage("fsaverage7")
-     lh_coords, lh_faces = nilearn.surface.load_surf_mesh(fsaverage["sphere_left"])
-     rh_coords, rh_faces = nilearn.surface.load_surf_mesh(fsaverage["sphere_right"])
-     lh_xmin, lh_xmax = np.min(lh_coords[:, 0]), np.max(lh_coords[:, 0])
-     lh_xmax = lh_xmin + (lh_xmax - lh_xmin) * 1.5
-     rh_xmin, rh_xmax = np.min(rh_coords[:, 0]), np.max(rh_coords[:, 0])
-     if rh_xmin < lh_xmax:
-         rh_coords[:, 0] += lh_xmax - rh_xmin
-     coords = np.concatenate((lh_coords, rh_coords), axis=0)
-     coords = torch.tensor(coords)
-     return coords
-
- # %%
  class BrainEncodingModel(nn.Module):
      def __init__(
          self,
          cfg: AutoConfig,
+         n_voxel_dict = {'subj01': 327684},
      ):

          super().__init__()
-         n_voxel_dict = {'subj01': 327684,
-                         'subj02': 327684,
-                         'subj03': 327684,
-                         'subj04': 327684,
-                         'subj05': 327684,
-                         'subj06': 327684,
-                         'subj07': 327684,
-                         'subj08': 327684}
          self.subject_list = list(n_voxel_dict.keys())
+         assert len(self.subject_list) == 1, "Only one subject is supported"

          self.layers = cfg.MODEL.BACKBONE.LAYERS
          self.layers_small = cfg.MODEL.BACKBONE_SMALL.LAYERS
@@ -82,19 +48,13 @@ class BrainEncodingModel(nn.Module):
          cfg.MODEL.CONV_HEAD.WIDTH = int(cfg.MODEL.CONV_HEAD.WIDTH * r)
          self.cfg = cfg

-         self.behav_embed: SubjectBehaviorEmbed = build_behavior_embed(cfg)
-         # behavior is not used, just a placeholder
-
          self.backbone: AdaLNLoRADiNOv2ViT = build_backbone(cfg)
          self.conv_blocks: DictConvBlocks = build_conv_blocks(cfg)
          self.cls_blocks: ClassTokenMLPs = build_class_token_mlp(cfg)

          def build_each_subject(fn, subject_list):
              return nn.ModuleDict({subject: fn() for subject in subject_list})
-
-         self.coords = get_coords() # [327684, 3], for layer selector and retina mapper
-         self.coords = nn.Parameter(self.coords, requires_grad=False)
-
+
          self.layer_selector: Dict[str, CachedCoordsMLP] = build_each_subject(
              partial(
                  build_coords_mlp,
@@ -131,25 +91,26 @@ class BrainEncodingModel(nn.Module):
                  for subject in self.subject_list
              }
          )
+
+         self.coords : nn.Parameter = None


      def forward(
          self,
          x: Tensor, # [B, C, H, W]
-         subject: str,
          voxel_indices: Optional[Tensor] = None,
          chunk_size=4096,
+         **kwargs,
      ):
+         coords = self.coords
+         subject = self.subject_list[0]
+
          bsz = x.shape[0]
          device = x.device
          dtype = x.dtype

-         # bhv is not used, just a placeholder
-         bhv = torch.zeros((bsz, self.cfg.MODEL.COND.IN_DIM), device=device, dtype=dtype) # [B, D_B=35]
-         c = self.behav_embed(bhv, subject=subject) # [B, D_C]
-
          x_retina_grid, x_cls_dict = self.backbone.get_intermediate_layers(
-             x, n=self.layers, c=c
+             x, n=self.layers, c=None
          )
          x_retina_grid = self.conv_blocks(x_retina_grid)
          x_cls_dict = self.cls_blocks(x_cls_dict)
@@ -161,7 +122,6 @@ class BrainEncodingModel(nn.Module):
          #############################

          # divide voxels into chunks to avoid OOM
-         coords = self.coords
          n_voxels = coords.shape[0]
          if voxel_indices is None or voxel_indices == ...:
              voxel_indices = torch.arange(n_voxels, device=coords.device)
@@ -186,7 +146,7 @@ class BrainEncodingModel(nn.Module):
          reg_layer = torch.cat(reg_layers, dim=0).mean() # [1]

          # if self.training:
-         # return out_y, reg_layer
+         # return out_y, reg_layer
          # else:
          return out_y

@@ -260,3 +220,64 @@ class BrainEncodingModel(nn.Module):

          return out_y, reg_layer # [B, N], [N]

+
+
+ def _load_one_model(model_path: str, subject: str='subj01', cfg_path: str=None):
+     cfg = load_from_yaml(cfg_path)
+
+     # load model weights
+     sd = torch.load(model_path, map_location='cpu')
+     n_voxels = sd[f'model.voxel_outs_weight.{subject}.weight'].shape[0]
+     # create model
+     model = BrainEncodingModel(cfg, {subject: n_voxels})
+
+     # save voxel's coordinates to model
+     coords = sd[f'coord_dict.{subject}']
+     model.coords = nn.Parameter(coords)
+
+     # load weights
+     filtered_sd = {k: v for k, v in sd.items() if k.startswith('model')}
+     filtered_sd = {k[6:]: v for k, v in filtered_sd.items() if k.startswith('model')}
+     filtered_sd['coords'] = model.coords # add coordinates of voxels
+     model.load_state_dict(filtered_sd)
+
+     model = model.eval()
+     return model
+
+
+ class TowPartModel(nn.Module):
+     def __init__(self, model_part1, model_part2, part1_voxel_indices):
+         super().__init__()
+         self.model_part1 = model_part1
+         self.model_part2 = model_part2
+         self.part1_voxel_indices = part1_voxel_indices
+
+     def forward(self, x):
+         # x: [B, 3, 224, 224] # image after normalization
+         out1 = self.model_part1(x)
+         out2 = self.model_part2(x)
+         out = out2
+         out[:, self.part1_voxel_indices] = out1
+         return out
+
+
+
+
+ # %%
+ if __name__ == '__main__':
+     # model_path = "/nfscc/alg23/xalex_distill2/high/t826c6_00016_DATASET.SUBJECT_LIST=subj01,LOSS.DARK.MAX_EPOCH=90,/soup.pth"
+     subject = 'subj01'
+     cfg_path = "/workspace/model_packed2/config.yaml"
+     model_path1 = f"/workspace/model_packed2/ckpts/{subject}_part1.pth"
+     model_path2 = f"/workspace/model_packed2/ckpts/{subject}_part2.pth"
+     model1 = _load_one_model(model_path1, subject, cfg_path)
+     model2 = _load_one_model(model_path2, subject, cfg_path)
+     voxel_indices_path = "/workspace/model_packed2/ckpts/part1_voxel_indices.pt"
+     voxel_indices = torch.load(voxel_indices_path)[subject]
+     model = TowPartModel(model1, model2, voxel_indices)
+
+     x = torch.randn(1, 3, 224, 224)
+     x = x.cuda()
+     model = model.cuda()
+     out = model(x)
+     print(out.shape)
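
A short self-contained check, offered as a sketch (it assumes the /workspace/model_packed2 layout used above and a CUDA device), that TowPartModel's stitched output matches composing the two parts by hand: part 2 predicts every fsaverage vertex and part 1 overwrites the high-noise-ceiling ones.

import torch
from model import _load_one_model, TowPartModel

subject = 'subj01'
cfg_path = "/workspace/model_packed2/config.yaml"
model1 = _load_one_model(f"/workspace/model_packed2/ckpts/{subject}_part1.pth", subject, cfg_path)
model2 = _load_one_model(f"/workspace/model_packed2/ckpts/{subject}_part2.pth", subject, cfg_path)
voxel_indices = torch.load("/workspace/model_packed2/ckpts/part1_voxel_indices.pt")[subject]
model = TowPartModel(model1, model2, voxel_indices).cuda().eval()

x = torch.randn(1, 3, 224, 224).cuda()       # stands in for a normalized image
with torch.no_grad():
    stitched = model2(x).clone()              # full-surface prediction from part 2
    stitched[:, voxel_indices] = model1(x)    # overwrite part-1 vertices, as in TowPartModel.forward
    assert torch.allclose(stitched, model(x))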