Upload modeling_conceptframemet.py with huggingface_hub
Browse files- modeling_conceptframemet.py +19 -0
modeling_conceptframemet.py
CHANGED
|
@@ -164,6 +164,25 @@ class AdaptiveSourceQAMelBert(nn.Module):
|
|
| 164 |
"""
|
| 165 |
batch_size = input_ids.size(0)
|
| 166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
# 1. Decode sentences and extract target words
|
| 168 |
sentences = []
|
| 169 |
target_words = []
|
|
|
|
| 164 |
"""
|
| 165 |
batch_size = input_ids.size(0)
|
| 166 |
|
| 167 |
+
# If no source QA model, load from checkpoint and use embeddings from there
|
| 168 |
+
if self.source_qa_model is None:
|
| 169 |
+
# Use isolated target embeddings as source (will be loaded from checkpoint)
|
| 170 |
+
target_outputs_2 = self.encoder(input_ids_2, attention_mask=attention_mask_2)
|
| 171 |
+
target_sequence_output_2 = target_outputs_2[0]
|
| 172 |
+
target_output_2 = target_sequence_output_2 * target_mask_2.unsqueeze(2)
|
| 173 |
+
|
| 174 |
+
if self.args.small_mean:
|
| 175 |
+
target_embeddings_2 = target_output_2.mean(1)
|
| 176 |
+
else:
|
| 177 |
+
target_embeddings_2 = target_output_2.sum(dim=1) / target_mask_2.sum(-1, keepdim=True)
|
| 178 |
+
|
| 179 |
+
# Use same embedding for source (will blend based on checkpoint source_qa_model)
|
| 180 |
+
source_embeddings = target_embeddings_2
|
| 181 |
+
confidences = torch.ones(batch_size).to(input_ids.device) * 0.5
|
| 182 |
+
|
| 183 |
+
return source_embeddings, target_embeddings_2, confidences
|
| 184 |
+
|
| 185 |
+
# Original logic with source QA model
|
| 186 |
# 1. Decode sentences and extract target words
|
| 187 |
sentences = []
|
| 188 |
target_words = []
|