ga89tiy committed
Commit f5ae994
1 parent: 9de5cd4

last update
LLAVA_Biovil/biovil_t/transformer.py CHANGED
@@ -93,17 +93,11 @@ def forward(self, current_image: torch.Tensor, previous_image: Optional[torch.Te
 
     def forward_after_reshape(self,
                               x: torch.Tensor,
-                              pos_embed: torch.Tensor,
-                              x_previous: Optional[torch.Tensor] = None) -> torch.Tensor:
+                              pos_embed: torch.Tensor) -> torch.Tensor:
         B, L, _ = x.shape  # Batch, Sequence length, Feature dimension
 
         # Positional and type embeddings
         type_embed = self.type_embed[0].expand(B, L, -1)
-        if x_previous is not None:
-            x = torch.cat((x, x_previous), dim=1)
-            pos_embed = torch.cat((pos_embed, pos_embed), dim=1)
-            prev_type_embed = self.type_embed[1].expand(B, L, -1)
-            type_embed = torch.cat((type_embed, prev_type_embed), dim=1)
 
         # Add positional and type embeddings (used in query and key matching)
         pos_and_type_embed = pos_embed + type_embed
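Note on the transformer.py change: with the previous-image branch removed, forward_after_reshape now only adds one positional embedding and the single "current image" type embedding to the patch tokens, and self.type_embed[1] is never indexed anymore. A minimal, self-contained sketch of the remaining arithmetic (the feature dimension, token count, and the shape of type_embed are assumptions made for illustration, not taken from the repo):

import torch
import torch.nn as nn

class TypeEmbedDemo(nn.Module):
    def __init__(self, feature_dim: int = 128):
        super().__init__()
        # two rows were kept for current/previous images; only row 0 is used after this commit
        self.type_embed = nn.Parameter(torch.zeros(2, 1, 1, feature_dim))

    def forward_after_reshape(self, x: torch.Tensor, pos_embed: torch.Tensor) -> torch.Tensor:
        B, L, _ = x.shape                                 # Batch, Sequence length, Feature dimension
        type_embed = self.type_embed[0].expand(B, L, -1)  # same "current image" type for every token
        pos_and_type_embed = pos_embed + type_embed       # used in query and key matching downstream
        return x + pos_and_type_embed                     # placeholder for the attention blocks that follow

demo = TypeEmbedDemo()
tokens = torch.randn(2, 196, 128)   # e.g. a 14x14 patch grid per image
pos = torch.randn(1, 196, 128)      # positional embedding, broadcast over the batch
print(demo.forward_after_reshape(tokens, pos).shape)  # torch.Size([2, 196, 128])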
LLAVA_Biovil/llava/model/builder.py CHANGED
@@ -184,19 +184,6 @@ def load_from_hf(repo_id, filename, subfolder=None):
                     new_vision_tower_state_dict[new_k] = v
             print('Loaded additional vision tower weights...')
             vision_tower.load_state_dict(new_vision_tower_state_dict, strict=False)
-            # weight difference sum([torch.norm(value-vision_tower.state_dict()[key].cpu()) for key,value in new_vision_tower_state_dict.items()])
-
-        image_pooler = model.get_image_pooler()
-        if image_pooler is not None:
-            image_pooler.to(device=device, dtype=torch.float16)
-            if non_lora_trainables is not None and any(k.startswith('model.image_pooler.') for k in non_lora_trainables):
-                new_image_pooler_state_dict = {}
-                for k, v in non_lora_trainables.items():  # we need remapping, because state_dict from model is always like model.vision_tower. It should be vision_tower.
-                    if 'model.image_pooler.' in k:
-                        new_k = k.replace('model.image_pooler.', '')
-                        new_image_pooler_state_dict[new_k] = v
-                print('Loading additional image pooler weights...')
-                image_pooler.load_state_dict(new_image_pooler_state_dict, strict=True)
 
     if hasattr(model.config, "max_sequence_length"):
         context_len = model.config.max_sequence_length
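The branch that survives in builder.py still relies on the key remapping described in the deleted comment: checkpoint keys are stored as model.vision_tower.<...> while the vision tower expects its own namespace. A small sketch of that pattern in isolation (the state-dict entries below are hypothetical; only the prefix-stripping mirrors the builder code):

import torch

def remap_prefix(state_dict: dict, prefix: str = "model.vision_tower.") -> dict:
    """Keep only the keys under `prefix` and strip it so a submodule can load them."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

non_lora_trainables = {
    "model.vision_tower.encoder.conv1.weight": torch.zeros(8, 3, 3, 3),  # hypothetical key
    "model.mm_projector.0.weight": torch.zeros(16, 8),                   # unrelated key, dropped
}
print(list(remap_prefix(non_lora_trainables)))  # ['encoder.conv1.weight']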
LLAVA_Biovil/llava/model/language_model/llava_llama.py CHANGED
@@ -35,20 +35,19 @@ class LlavaConfig(LlamaConfig):
 class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
     config_class = LlavaConfig
 
-    def __init__(self, config: LlamaConfig, mv_type='none'):
-        super(LlavaLlamaModel, self).__init__(config, mv_type=mv_type)
+    def __init__(self, config: LlamaConfig):
+        super(LlavaLlamaModel, self).__init__(config)
 
 
 class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
     config_class = LlavaConfig
 
-    def __init__(self, config, mv_type='none'):
+    def __init__(self, config):
         super(LlamaForCausalLM, self).__init__(config)
-        self.model = LlavaLlamaModel(config, mv_type=mv_type)
+        self.model = LlavaLlamaModel(config)
         self.pretraining_tp = config.pretraining_tp
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.mv_type = mv_type
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -68,7 +67,6 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         images: Optional[torch.FloatTensor] = None,
-        prev_images: Optional[torch.FloatTensor] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         if inputs_embeds is None:
@@ -85,8 +83,7 @@ def forward(
                 attention_mask,
                 past_key_values,
                 labels,
-                images,
-                prev_images
+                images
             )
         output = super().forward(
             input_ids=input_ids,
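After this change the two wrappers in llava_llama.py are constructed from a config alone, and forward no longer accepts prev_images. A hedged usage sketch (the toy config sizes are assumptions, and the import path assumes this repository layout is on PYTHONPATH):

from LLAVA_Biovil.llava.model.language_model.llava_llama import LlavaConfig, LlavaLlamaForCausalLM

# Toy sizes so the example stays small; a real checkpoint would supply these values.
config = LlavaConfig(hidden_size=256, intermediate_size=512, num_hidden_layers=2,
                     num_attention_heads=4, vocab_size=1000)
model = LlavaLlamaForCausalLM(config)  # previously: LlavaLlamaForCausalLM(config, mv_type=...)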
LLAVA_Biovil/llava/model/llava_arch.py CHANGED
@@ -27,13 +27,12 @@
 
 class LlavaMetaModel:
 
-    def __init__(self, config, mv_type='none'):
+    def __init__(self, config):
         super(LlavaMetaModel, self).__init__(config)
 
         if hasattr(config, "mm_vision_tower"):
             self.vision_tower = build_vision_tower(config, delay_load=True)
             self.mm_projector = build_vision_projector(config)
-            self.image_pooler = build_image_pooler(config) if "pool" in mv_type else None
 
     def get_vision_tower(self):
         vision_tower = getattr(self, 'vision_tower', None)
@@ -51,7 +50,6 @@ def initialize_vision_modules(self, model_args, fsdp=None):
         pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
 
         self.config.mm_vision_tower = vision_tower
-        self.config.mv_type = getattr(model_args, 'mv_type', False)
 
         if self.get_vision_tower() is None:
             if self.config.mm_vision_tower == 'biovil':
@@ -188,87 +186,8 @@ def pad_embeddings_mv(self, embeddings, padding_value=0):
 
         return padded_embeddings.flatten(1,2), mask
 
-    def encode_images_pooled(self, images, split_sizes, num_imgs_present, num_imgs_past, mv_type="pool_all"):
-        image_pooler = self.get_image_pooler()
-        image_features = self.get_model().get_vision_tower()(images)
-        if self.get_model().config.mm_vision_tower == 'biovil':
-            image_features = image_features.patch_embeddings
-            # flatten
-            image_features = image_features.flatten(2).transpose(1,2)
-        if split_sizes is not None:
-            image_features = torch.split(image_features, split_sizes, dim=0)
-
-        if mv_type == "pool_all":
-            # merge present and past per batch
-            present_features = [image_features[i] for i in range(len(num_imgs_present))]
-            past_features = []
-            i = 0
-            for num_imgs_elem in num_imgs_past:
-                if num_imgs_elem != 0:
-                    past_features.append(image_features[i+len(num_imgs_present)])
-                    i += 1
-                else:
-                    past_features.append(None)
-
-            all_img_features = []
-            for idx, (batch_num_present, batch_num_past) in enumerate(zip(num_imgs_present, num_imgs_past)):
-                if batch_num_past == 0:
-                    all_img_features.append(present_features[idx])
-                else:
-                    all_img_features.append(torch.cat((present_features[idx], past_features[idx]), dim=0))
-
-            all_img_features, mask, token_type_ids = self.pad_embeddings(all_img_features, num_imgs_present, num_imgs_past)
-            all_img_features = image_pooler(all_img_features, mask, token_type_ids)
-
-        elif mv_type == "pool_concat":
-            present_features = [image_features[i] for i in range(len(num_imgs_present))]
-            past_features = [image_features[i+len(num_imgs_present)] for i in range(len(image_features)-len(num_imgs_present))]
-            present_features, mask_present, _ = self.pad_embeddings(present_features)
-            past_features, mask_past, _ = self.pad_embeddings(past_features)
-            present_features = image_pooler(present_features, mask_present)
-            past_features = image_pooler(past_features, mask_past)
-            # TODO maybe max pool on past features to save tokens
-            # concat present and past per batch if past is not empty
-            all_img_features = []
-            idx_present = 0
-            idx_past = 0
-            for batch_num_present, batch_num_past in zip(num_imgs_present, num_imgs_past):
-                if batch_num_past == 0:
-                    all_img_features.append(present_features[idx_present])
-                    idx_present += 1
-                else:
-                    all_img_features.append(torch.cat((present_features[idx_present], past_features[idx_past]), dim=0))
-                    idx_present += 1
-                    idx_past += 1
-        else:
-            raise NotImplementedError
-        if type(all_img_features) is list:
-            split_sizes = [image.shape[0] for image in all_img_features]
-            all_img_features = self.get_model().mm_projector(torch.cat(all_img_features, dim=0))
-            all_img_features = torch.split(all_img_features, split_sizes, dim=0)
-
-        else:
-            all_img_features = self.get_model().mm_projector(all_img_features)
-        return all_img_features
-
-    def encode_images_pooled_mv(self, images, split_sizes):
-        image_pooler = self.get_image_pooler()
-        image_features = self.get_model().get_vision_tower()(images)
-        if split_sizes is not None:
-            image_features = torch.split(image_features, split_sizes, dim=0)
-            image_features, mask = self.pad_embeddings_mv(image_features)
-            image_features = image_pooler(image_features, mask)
-        else:
-            mask = torch.ones((image_features.shape[0], image_features.shape[1]), dtype=torch.bool, device=image_features[0].device)
-            image_features = image_pooler(image_features, mask)
-        image_features = self.get_model().mm_projector(image_features)
-        return image_features
-
-    def get_image_pooler(self):
-        return self.get_model().get_image_pooler()
-
     def prepare_inputs_labels_for_multimodal(
-        self, input_ids, position_ids, attention_mask, past_key_values, labels, images, prev_images=None
+        self, input_ids, position_ids, attention_mask, past_key_values, labels, images
     ):
         vision_tower = self.get_vision_tower()
         if vision_tower is None or images is None or input_ids.shape[1] == 1:
@@ -283,35 +202,14 @@ def prepare_inputs_labels_for_multimodal(
             return input_ids, position_ids, attention_mask, past_key_values, None, labels
 
         if type(images) is list or images.ndim == 5:
-            if getattr(self.config, 'mv_type') == "concat":
-                concat_images = torch.cat([image for image in images], dim=0)
-                image_features = self.encode_images(concat_images)
-                split_sizes = [image.shape[0] for image in images]
-                image_features = torch.split(image_features, split_sizes, dim=0)
-                image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
-            if getattr(self.config, 'mv_type') == "pool_all":
-                concat_images = torch.cat((torch.cat([image for image in images], dim=0), torch.cat([image for image in prev_images if image is not None], dim=0))) # first present, then past, all will be merged
-                split_sizes = [image.shape[0] for image in images] + [image.shape[0] for image in prev_images if image is not None]
-                num_imgs_present = [image.shape[0] if image is not None else 0 for image in images]
-                num_imgs_past = [image.shape[0] if image is not None else 0 for image in prev_images]
-                image_features = self.encode_images_pooled(concat_images, split_sizes, num_imgs_present, num_imgs_past, "pool_all")
-            if getattr(self.config, 'mv_type') == "pool_concat": # TODO make sure to allow empty past -> shorter sequence
-                concat_images = torch.cat((torch.cat([image for image in images], dim=0), torch.cat([image for image in prev_images if image is not None], dim=0))) # first present, then past, all will be merged
-                split_sizes = [image.shape[0] for image in images] + [image.shape[0] for image in prev_images if image is not None]
-                num_imgs_present = [image.shape[0] if image is not None else 0 for image in images]
-                num_imgs_past = [image.shape[0] if image is not None else 0 for image in prev_images]
-                image_features = self.encode_images_pooled(concat_images, split_sizes, num_imgs_present, num_imgs_past, "pool_concat")
-            if getattr(self.config, 'mv_type') == "pool": #no past images
-                concat_images = torch.cat([image for image in images], dim=0)
-                split_sizes = [image.shape[0] for image in images]
-                image_features = self.encode_images_pooled_mv(concat_images, split_sizes)
+            concat_images = torch.cat([image for image in images], dim=0)
+            image_features = self.encode_images(concat_images)
+            split_sizes = [image.shape[0] for image in images]
+            image_features = torch.split(image_features, split_sizes, dim=0)
+            image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
+
         else:
-            if hasattr(self.config, 'mv_type') and getattr(self.config, 'mv_type') == "pool_all":
-                image_features = self.encode_images_pooled(images, None).to(self.device)
-            elif hasattr(self.config, 'mv_type') and getattr(self.config, 'mv_type') == "pool":
-                image_features = self.encode_images_pooled_mv(images, None).to(self.device)
-            else:
-                image_features = self.encode_images(images).to(self.device)
+            image_features = self.encode_images(images).to(self.device)
 
         # TODO: image start / end is not implemented here to support pretraining.
         if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
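What remains of the multi-image branch in prepare_inputs_labels_for_multimodal is a plain concat-encode-split pattern: stack every sample's images into one encoder call, split the features back per sample, then flatten the image and patch dimensions into a single token sequence. A standalone sketch with a stand-in encoder (the patch count and feature size are assumptions for illustration):

import torch

def encode_images_stub(pixel_batch: torch.Tensor) -> torch.Tensor:
    """Stand-in for self.encode_images: one 32-d feature per patch, 576 patches per image."""
    return torch.randn(pixel_batch.shape[0], 576, 32)

images = [torch.randn(2, 3, 336, 336), torch.randn(1, 3, 336, 336)]  # sample 1 has 2 images, sample 2 has 1
concat_images = torch.cat([image for image in images], dim=0)        # (3, 3, 336, 336)
image_features = encode_images_stub(concat_images)                   # (3, 576, 32)
split_sizes = [image.shape[0] for image in images]                   # [2, 1]
image_features = torch.split(image_features, split_sizes, dim=0)
image_features = [x.flatten(0, 1) for x in image_features]           # [(1152, 32), (576, 32)]
print([tuple(f.shape) for f in image_features])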
__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.68 kB) added.
 
example_code.py CHANGED
@@ -6,7 +6,7 @@ import requests
 import torch
 from PIL import Image
 import numpy as np
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download, hf_hub_download
 
 from LLAVA_Biovil.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, remap_to_uint8
 from LLAVA_Biovil.llava.model.builder import load_pretrained_model
@@ -18,13 +18,12 @@ from utils import create_chest_xray_transform_for_inference, init_chexpert_predi
 
 def load_model_from_huggingface(repo_id):
     # Download model files
-    model_path = snapshot_download(repo_id=repo_id, revision="main", force_download=True)
+    model_path = snapshot_download(repo_id=repo_id, revision="main")
     model_path = Path(model_path)
 
     tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base='liuhaotian/llava-v1.5-7b',
                                                                            model_name="llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000", load_8bit=False, load_4bit=False)
 
-
     return tokenizer, model, image_processor, context_len
 
 
@@ -37,7 +36,7 @@ if __name__ == '__main__':
     image = remap_to_uint8(np.array(image))
     image = Image.fromarray(image).convert("L")
 
-    tokenizer, model, image_processor, context_len = load_model_from_huggingface(repo_id="Chantal/RaDialog-interactive-radiology-report-generation")
+    tokenizer, model, image_processor, context_len = load_model_from_huggingface(repo_id="ChantalPellegrini/RaDialog-interactive-radiology-report-generation")
     cp_model, cp_class_names, cp_transforms = init_chexpert_predictor()
 
     model.config.tokenizer_padding_side = "left"
@@ -82,27 +81,6 @@ if __name__ == '__main__':
     pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
     print("ASSISTANT: ", pred)
 
-    # add prediction to conversation
-    conv.messages.pop()
-    conv.append_message("ASSISTANT", pred)
-    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
-    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
-
-    # generate a report
-    with torch.inference_mode():
-        output_ids = model.generate(
-            input_ids,
-            images=image_tensor,
-            do_sample=False,
-            use_cache=True,
-            max_new_tokens=300,
-            stopping_criteria=[stopping_criteria],
-            pad_token_id=tokenizer.pad_token_id
-        )
-
-    pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
-    print("ASSISTANT: ", pred)
-
     # add prediction to conversation
     conv.messages.pop()
     conv.append_message("ASSISTANT", pred)
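On the example_code.py side, dropping force_download=True lets snapshot_download reuse the local Hugging Face cache on repeat runs instead of re-fetching every file, and the repo id moves to the ChantalPellegrini namespace. A minimal sketch of the download step on its own (network access and the public repo are assumed):

from pathlib import Path
from huggingface_hub import snapshot_download

# Cached under ~/.cache/huggingface by default; pass force_download=True only to refresh.
model_path = Path(snapshot_download(
    repo_id="ChantalPellegrini/RaDialog-interactive-radiology-report-generation",
    revision="main"))
print(model_path)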
findings_classifier/__pycache__/chexpert_train.cpython-310.pyc CHANGED
Binary files a/findings_classifier/__pycache__/chexpert_train.cpython-310.pyc and b/findings_classifier/__pycache__/chexpert_train.cpython-310.pyc differ