Upload modeling_conceptframemet.py with huggingface_hub
Browse files
modeling_conceptframemet.py
CHANGED
|
@@ -68,13 +68,14 @@ class AdaptiveSourceQAMelBert(nn.Module):
|
|
| 68 |
self.source_alpha = getattr(args, 'source_alpha', 0.3)
|
| 69 |
self.metaphor_threshold = getattr(args, 'metaphor_threshold', 0.5)
|
| 70 |
|
| 71 |
-
# Freeze or unfreeze source QA model
|
| 72 |
-
if
|
| 73 |
-
|
| 74 |
-
param
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
param
|
|
|
|
| 78 |
|
| 79 |
# Load source labels
|
| 80 |
self.source_id2label = {}
|
|
|
|
| 68 |
self.source_alpha = getattr(args, 'source_alpha', 0.3)
|
| 69 |
self.metaphor_threshold = getattr(args, 'metaphor_threshold', 0.5)
|
| 70 |
|
| 71 |
+
# Freeze or unfreeze source QA model (only if it exists)
|
| 72 |
+
if self.source_qa_model is not None:
|
| 73 |
+
if not getattr(args, 'unfreeze_source_qa', False):
|
| 74 |
+
for param in self.source_qa_model.parameters():
|
| 75 |
+
param.requires_grad = False
|
| 76 |
+
else:
|
| 77 |
+
for param in self.source_qa_model.parameters():
|
| 78 |
+
param.requires_grad = True
|
| 79 |
|
| 80 |
# Load source labels
|
| 81 |
self.source_id2label = {}
|