nixie1981 commited on
Commit
b6384b1
·
verified ·
1 Parent(s): 28d1256

Upload modeling_conceptframemet.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_conceptframemet.py +19 -0
modeling_conceptframemet.py CHANGED
@@ -164,6 +164,25 @@ class AdaptiveSourceQAMelBert(nn.Module):
164
  """
165
  batch_size = input_ids.size(0)
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  # 1. Decode sentences and extract target words
168
  sentences = []
169
  target_words = []
 
164
  """
165
  batch_size = input_ids.size(0)
166
 
167
+ # If no source QA model, load from checkpoint and use embeddings from there
168
+ if self.source_qa_model is None:
169
+ # Use isolated target embeddings as source (will be loaded from checkpoint)
170
+ target_outputs_2 = self.encoder(input_ids_2, attention_mask=attention_mask_2)
171
+ target_sequence_output_2 = target_outputs_2[0]
172
+ target_output_2 = target_sequence_output_2 * target_mask_2.unsqueeze(2)
173
+
174
+ if self.args.small_mean:
175
+ target_embeddings_2 = target_output_2.mean(1)
176
+ else:
177
+ target_embeddings_2 = target_output_2.sum(dim=1) / target_mask_2.sum(-1, keepdim=True)
178
+
179
+ # Use same embedding for source (will blend based on checkpoint source_qa_model)
180
+ source_embeddings = target_embeddings_2
181
+ confidences = torch.ones(batch_size).to(input_ids.device) * 0.5
182
+
183
+ return source_embeddings, target_embeddings_2, confidences
184
+
185
+ # Original logic with source QA model
186
  # 1. Decode sentences and extract target words
187
  sentences = []
188
  target_words = []