lukeingawesome committed
Commit 6f83848 · verified · 1 Parent(s): dd5ed8b

Fix model structure to match checkpoint (wrap LlamaBiModel in self.model)

Files changed (1)
  1. modeling_llm2vec4cxr.py +18 -20
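Why the restructuring matters: the previous class subclassed LlamaBiModel directly, so the backbone's parameters sat at the top level of the state dict, while (per the commit message and the new docstring) the saved checkpoint expects them under a "model." prefix next to "latent_attn.*". The following is a minimal stand-in sketch of that key mismatch, assuming the checkpoint layout described in the commit; Inner, OldStyle, and NewStyle are illustrative placeholders, not the real LlamaBiModel or LatentAttentionPooling.

# Stand-in illustration (not the real classes): why wrapping the backbone in
# `self.model` makes parameter names line up with the saved checkpoint.
import torch.nn as nn

class Inner(nn.Module):                      # stand-in for LlamaBiModel
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 4)

class OldStyle(Inner):                       # old layout: subclass of the backbone
    def __init__(self):
        super().__init__()
        self.latent_attn = nn.Linear(4, 4)   # stand-in for LatentAttentionPooling

class NewStyle(nn.Module):                   # new layout: backbone wrapped in self.model
    def __init__(self):
        super().__init__()
        self.model = Inner()
        self.latent_attn = nn.Linear(4, 4)

print(sorted(OldStyle().state_dict()))  # embed_tokens.weight, latent_attn.*  (no "model." prefix)
print(sorted(NewStyle().state_dict()))  # model.embed_tokens.weight, latent_attn.*  (matches the checkpoint layout)

With the prefixes aligned, from_pretrained should load the checkpoint without remapping keys or warning about unused latent attention weights.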
modeling_llm2vec4cxr.py CHANGED
@@ -3,6 +3,8 @@ Custom model class for LLM2Vec4CXR that properly handles latent attention poolin
 """
 
 from llm2vec.models.bidirectional_llama import LlamaBiModel
+from transformers import PreTrainedModel
+from transformers.models.llama.configuration_llama import LlamaConfig
 # from llm2vec.pooling import LatentAttentionPooling
 from .pooling_latent import LatentAttentionPooling
 from transformers import AutoTokenizer
@@ -11,46 +13,42 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-class LLM2Vec4CXRModel(LlamaBiModel):
+class LLM2Vec4CXRModel(PreTrainedModel):
     """
-    Custom LlamaBiModel that includes latent attention pooling by default.
-    This prevents the warning about unused latent attention weights.
+    Wrapper model that includes LlamaBiModel and latent attention pooling.
+    Structure matches the saved checkpoint: self.model + self.latent_attn
     """
+    config_class = LlamaConfig
 
     def __init__(self, config, **kwargs):
         super().__init__(config, **kwargs)
 
+        # Wrap the LlamaBiModel
+        self.model = LlamaBiModel(config)
+
         # Initialize latent attention pooling
         self.latent_attn = LatentAttentionPooling(
            d_model=config.hidden_size,
            num_heads=8,     # Standard for this model size
            num_latents=512  # Standard for LLM2Vec
        )
-
-        # Move to the same device/dtype as the base model
-        if hasattr(self, 'model') and hasattr(self.model, 'embed_tokens'):
-            device = self.model.embed_tokens.weight.device
-            dtype = self.model.embed_tokens.weight.dtype
-            self.latent_attn = self.latent_attn.to(device=device, dtype=dtype)
 
     def forward(self, input_ids, attention_mask=None, embed_mask=None, **kwargs):
        """
        Forward pass that properly handles latent attention pooling.
        """
        # Get base model output
-        outputs = super().forward(input_ids, attention_mask=attention_mask, **kwargs)
+        outputs = self.model(input_ids, attention_mask=attention_mask, **kwargs)
 
-        # If we have latent attention pooling, apply it
-        if hasattr(self, 'latent_attn') and self.latent_attn is not None:
-            if embed_mask is not None:
-                # Use embed_mask for instruction-following tasks
-                pooled_output = self.latent_attn(outputs.last_hidden_state, embed_mask)
-            else:
-                # Use attention_mask for simple encoding
-                pooled_output = self.latent_attn(outputs.last_hidden_state, attention_mask)
-            return pooled_output
+        # Apply latent attention pooling
+        if embed_mask is not None:
+            # Use embed_mask for instruction-following tasks
+            pooled_output = self.latent_attn(outputs.last_hidden_state, embed_mask)
+        else:
+            # Use attention_mask for simple encoding
+            pooled_output = self.latent_attn(outputs.last_hidden_state, attention_mask)
 
-        return outputs.last_hidden_state
+        return pooled_output
 
     # --- Convenience tokenizer (lazy) -------------------------------------
     def _get_tokenizer(self):
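For reference, a hedged usage sketch of the restructured forward path. The repo id is a placeholder, and loading through AutoModel with trust_remote_code assumes the repository's auto_map registers LLM2Vec4CXRModel; per the diff, calling the model with only attention_mask takes the simple-encoding branch, while passing an embed_mask (marking the span of tokens to embed) takes the instruction-following branch, and in both cases the forward now returns the pooled output rather than last_hidden_state.

# Hypothetical usage sketch; "user/llm2vec4cxr" is a placeholder repo id, and
# AutoModel + trust_remote_code assumes the repo's auto_map points at LLM2Vec4CXRModel.
import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "user/llm2vec4cxr"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()

batch = tokenizer("There is a small left pleural effusion.", return_tensors="pt")

with torch.no_grad():
    # No embed_mask -> the "simple encoding" branch pools over attention_mask.
    emb = model(batch["input_ids"], attention_mask=batch["attention_mask"])

print(emb.shape)  # pooled embedding produced by the latent-attention pooler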