Tamás Ficsor
committed on
Commit
•
21d6318
1
Parent(s):
11cdb73
add model
Browse files- modeling_charmen.py +35 -14
modeling_charmen.py
CHANGED
@@ -26,7 +26,7 @@ class CharmenElectraModelOutput(ModelOutput):
|
|
26 |
class CharmenElectraModel(ElectraPreTrainedModel):
|
27 |
config_class = CharmenElectraConfig
|
28 |
|
29 |
-
def __init__(self, config: CharmenElectraConfig, compatibility_with_transformers=False):
|
30 |
super().__init__(config)
|
31 |
self.embeddings: GBST = GBST(
|
32 |
num_tokens=config.vocab_size,
|
@@ -178,16 +178,20 @@ class CharmenElectraModel(ElectraPreTrainedModel):
|
|
178 |
prefix = "discriminator.electra."
|
179 |
|
180 |
for key, value in state_dict.items():
|
|
|
|
|
181 |
if key.startswith(prefix):
|
182 |
model[key[len(prefix):]] = value
|
|
|
|
|
183 |
|
184 |
-
super(CharmenElectraModel, self).load_state_dict(model, strict)
|
185 |
|
186 |
|
187 |
class CharmenElectraClassificationHead(nn.Module):
|
188 |
"""Head for sentence-level classification tasks."""
|
189 |
|
190 |
-
def __init__(self, config: CharmenElectraConfig):
|
191 |
super().__init__()
|
192 |
self.config = config
|
193 |
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
@@ -211,7 +215,7 @@ class CharmenElectraClassificationHead(nn.Module):
|
|
211 |
class CharmenElectraForSequenceClassification(ElectraForSequenceClassification):
|
212 |
config_class = CharmenElectraConfig
|
213 |
|
214 |
-
def __init__(self, config: CharmenElectraConfig, class_weight=None, label_smoothing=0.0):
|
215 |
super().__init__(config)
|
216 |
|
217 |
self.num_labels = config.num_labels
|
@@ -252,17 +256,26 @@ class CharmenElectraForSequenceClassification(ElectraForSequenceClassification):
|
|
252 |
|
253 |
def load_state_dict(self, state_dict: OrderDictType[str, Tensor], strict: bool = True):
|
254 |
model = OrderedDict()
|
255 |
-
prefix = "discriminator.
|
256 |
|
257 |
for key, value in state_dict.items():
|
|
|
|
|
258 |
if key.startswith(prefix):
|
|
|
|
|
259 |
model[key[len(prefix):]] = value
|
|
|
|
|
|
|
|
|
260 |
|
261 |
-
self.model.load_state_dict(state_dict=model, strict=
|
|
|
262 |
|
263 |
|
264 |
class CharmenElectraForTokenClassification(ElectraForTokenClassification):
|
265 |
-
def __init__(self, config: CharmenElectraConfig, class_weight=None, label_smoothing=0.0):
|
266 |
super().__init__(config)
|
267 |
|
268 |
self.num_labels = config.num_labels
|
@@ -317,13 +330,17 @@ class CharmenElectraForTokenClassification(ElectraForTokenClassification):
|
|
317 |
|
318 |
def load_state_dict(self, state_dict: OrderDictType[str, Tensor], strict: bool = True):
|
319 |
model = OrderedDict()
|
320 |
-
prefix = "discriminator.
|
321 |
|
322 |
for key, value in state_dict.items():
|
|
|
|
|
323 |
if key.startswith(prefix):
|
324 |
-
model[key[len(prefix):]] = value
|
|
|
|
|
325 |
|
326 |
-
self.
|
327 |
|
328 |
|
329 |
class Pooler(nn.Module):
|
@@ -342,7 +359,7 @@ class Pooler(nn.Module):
|
|
342 |
|
343 |
|
344 |
class CharmenElectraForMultipleChoice(ElectraForMultipleChoice):
|
345 |
-
def __init__(self, config: CharmenElectraConfig, class_weight=None, label_smoothing=0.0):
|
346 |
super().__init__(config)
|
347 |
self.num_labels = config.num_labels
|
348 |
self.config = config
|
@@ -401,10 +418,14 @@ class CharmenElectraForMultipleChoice(ElectraForMultipleChoice):
|
|
401 |
|
402 |
def load_state_dict(self, state_dict: OrderDictType[str, Tensor], strict: bool = True):
|
403 |
model = OrderedDict()
|
404 |
-
prefix = "discriminator.
|
405 |
|
406 |
for key, value in state_dict.items():
|
|
|
|
|
407 |
if key.startswith(prefix):
|
408 |
-
model[key[len(prefix):]] = value
|
|
|
|
|
409 |
|
410 |
-
self.
|
|
|
26 |
class CharmenElectraModel(ElectraPreTrainedModel):
|
27 |
config_class = CharmenElectraConfig
|
28 |
|
29 |
+
def __init__(self, config: CharmenElectraConfig, compatibility_with_transformers=False, **kwargs):
|
30 |
super().__init__(config)
|
31 |
self.embeddings: GBST = GBST(
|
32 |
num_tokens=config.vocab_size,
|
|
|
178 |
prefix = "discriminator.electra."
|
179 |
|
180 |
for key, value in state_dict.items():
|
181 |
+
if key.startswith('generator'):
|
182 |
+
continue
|
183 |
if key.startswith(prefix):
|
184 |
model[key[len(prefix):]] = value
|
185 |
+
else:
|
186 |
+
continue
|
187 |
|
188 |
+
super(CharmenElectraModel, self).load_state_dict(state_dict=model, strict=strict)
|
189 |
|
190 |
|
191 |
class CharmenElectraClassificationHead(nn.Module):
|
192 |
"""Head for sentence-level classification tasks."""
|
193 |
|
194 |
+
def __init__(self, config: CharmenElectraConfig, **kwargs):
|
195 |
super().__init__()
|
196 |
self.config = config
|
197 |
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
|
|
215 |
class CharmenElectraForSequenceClassification(ElectraForSequenceClassification):
|
216 |
config_class = CharmenElectraConfig
|
217 |
|
218 |
+
def __init__(self, config: CharmenElectraConfig, class_weight=None, label_smoothing=0.0, **kwargs):
|
219 |
super().__init__(config)
|
220 |
|
221 |
self.num_labels = config.num_labels
|
|
|
256 |
|
257 |
def load_state_dict(self, state_dict: OrderDictType[str, Tensor], strict: bool = True):
|
258 |
model = OrderedDict()
|
259 |
+
prefix = "discriminator.model"
|
260 |
|
261 |
for key, value in state_dict.items():
|
262 |
+
if key.startswith('generator'):
|
263 |
+
continue
|
264 |
if key.startswith(prefix):
|
265 |
+
if 'discriminator_predictions' in key:
|
266 |
+
continue
|
267 |
model[key[len(prefix):]] = value
|
268 |
+
else:
|
269 |
+
if key.startswith('sop'):
|
270 |
+
continue
|
271 |
+
model[key] = value
|
272 |
|
273 |
+
self.model.load_state_dict(state_dict=model, strict=False)
|
274 |
+
self.classifier.load_state_dict(state_dict=model, strict=False)
|
275 |
|
276 |
|
277 |
class CharmenElectraForTokenClassification(ElectraForTokenClassification):
|
278 |
+
def __init__(self, config: CharmenElectraConfig, class_weight=None, label_smoothing=0.0, **kwargs):
|
279 |
super().__init__(config)
|
280 |
|
281 |
self.num_labels = config.num_labels
|
|
|
330 |
|
331 |
def load_state_dict(self, state_dict: OrderDictType[str, Tensor], strict: bool = True):
|
332 |
model = OrderedDict()
|
333 |
+
prefix = "discriminator."
|
334 |
|
335 |
for key, value in state_dict.items():
|
336 |
+
if key.startswith('generator'):
|
337 |
+
continue
|
338 |
if key.startswith(prefix):
|
339 |
+
model[key[len(prefix):].replace('electra', 'model')] = value
|
340 |
+
else:
|
341 |
+
model[key] = value
|
342 |
|
343 |
+
super(CharmenElectraForTokenClassification, self).load_state_dict(state_dict=model, strict=strict)
|
344 |
|
345 |
|
346 |
class Pooler(nn.Module):
|
|
|
359 |
|
360 |
|
361 |
class CharmenElectraForMultipleChoice(ElectraForMultipleChoice):
|
362 |
+
def __init__(self, config: CharmenElectraConfig, class_weight=None, label_smoothing=0.0, **kwargs):
|
363 |
super().__init__(config)
|
364 |
self.num_labels = config.num_labels
|
365 |
self.config = config
|
|
|
418 |
|
419 |
def load_state_dict(self, state_dict: OrderDictType[str, Tensor], strict: bool = True):
|
420 |
model = OrderedDict()
|
421 |
+
prefix = "discriminator."
|
422 |
|
423 |
for key, value in state_dict.items():
|
424 |
+
if key.startswith('generator'):
|
425 |
+
continue
|
426 |
if key.startswith(prefix):
|
427 |
+
model[key[len(prefix):].replace('electra', 'model')] = value
|
428 |
+
else:
|
429 |
+
model[key] = value
|
430 |
|
431 |
+
super(CharmenElectraForMultipleChoice, self).load_state_dict(state_dict=model, strict=strict)
|