nixie1981 committed on
Commit
748a347
·
verified ·
1 Parent(s): 939da56

Upload modeling_conceptframemet.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_conceptframemet.py +32 -4
modeling_conceptframemet.py CHANGED
@@ -55,6 +55,18 @@ class AdaptiveSourceQAMelBert(nn.Module):
55
  super(AdaptiveSourceQAMelBert, self).__init__()
56
  self.num_labels = num_labels
57
  self.encoder = Model
 
 
 
 
 
 
 
 
 
 
 
 
58
  self.source_qa_model = Source_QA_Model
59
  self.source_qa_tokenizer = source_qa_tokenizer
60
  self.melbert_tokenizer = melbert_tokenizer
@@ -81,10 +93,26 @@ class AdaptiveSourceQAMelBert(nn.Module):
81
  self.source_id2label = {}
82
  try:
83
  import json
84
- with open('source_finder/source_labels.json', 'r') as f:
85
- source_label2id = json.load(f)
86
- self.source_id2label = {v: k for k, v in source_label2id.items()}
87
- print(f"✓ Loaded {len(self.source_id2label)} source domain labels")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  except Exception as e:
89
  print(f"❌ Warning: Could not load source labels: {e}")
90
 
 
55
  super(AdaptiveSourceQAMelBert, self).__init__()
56
  self.num_labels = num_labels
57
  self.encoder = Model
58
+
59
+ # FIX: Resize token_type_embeddings to match training (type_vocab_size=4)
60
+ if hasattr(self.encoder, 'embeddings') and hasattr(self.encoder.embeddings, 'token_type_embeddings'):
61
+ if self.encoder.embeddings.token_type_embeddings.weight.shape[0] != 4:
62
+ old_embeddings = self.encoder.embeddings.token_type_embeddings
63
+ new_embeddings = nn.Embedding(4, old_embeddings.embedding_dim)
64
+ new_embeddings.weight.data[0] = old_embeddings.weight.data[0]
65
+ new_embeddings.weight.data[1:].normal_(mean=0.0, std=config.initializer_range)
66
+ self.encoder.embeddings.token_type_embeddings = new_embeddings
67
+ if hasattr(self.encoder, 'config'):
68
+ self.encoder.config.type_vocab_size = 4
69
+
70
  self.source_qa_model = Source_QA_Model
71
  self.source_qa_tokenizer = source_qa_tokenizer
72
  self.melbert_tokenizer = melbert_tokenizer
 
93
  self.source_id2label = {}
94
  try:
95
  import json
96
+ import os
97
+ # Try multiple paths
98
+ possible_paths = [
99
+ 'source_labels.json', # Same directory as model file
100
+ 'source_finder/source_labels.json', # Original location
101
+ os.path.join(os.path.dirname(__file__), 'source_labels.json'), # Next to this file
102
+ ]
103
+
104
+ for path in possible_paths:
105
+ try:
106
+ with open(path, 'r') as f:
107
+ source_label2id = json.load(f)
108
+ self.source_id2label = {v: k for k, v in source_label2id.items()}
109
+ print(f"✓ Loaded {len(self.source_id2label)} source domain labels from {path}")
110
+ break
111
+ except:
112
+ continue
113
+
114
+ if not self.source_id2label:
115
+ print(f"❌ Warning: Could not load source labels from any location")
116
  except Exception as e:
117
  print(f"❌ Warning: Could not load source labels: {e}")
118