End of training
Browse files- DisamBertCrossEncoder.py +125 -0
- README.md +19 -19
- config.json +3 -0
- model.safetensors +1 -1
- tokenizer.json +8 -1
DisamBertCrossEncoder.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections.abc import Generator, Iterable
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from enum import StrEnum
|
| 4 |
+
from itertools import chain
|
| 5 |
+
|
| 6 |
+
from nltk.corpus import wordnet
|
| 7 |
+
from nltk.metrics import edit_distance
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from transformers import (
|
| 13 |
+
AutoConfig,
|
| 14 |
+
AutoModel,
|
| 15 |
+
AutoTokenizer,
|
| 16 |
+
ModernBertModel,
|
| 17 |
+
PreTrainedConfig,
|
| 18 |
+
PreTrainedModel,
|
| 19 |
+
)
|
| 20 |
+
from transformers.modeling_outputs import SequenceClassifierOutput
|
| 21 |
+
|
| 22 |
+
# Default batch size for scoring (sentence, gloss) pairs.
# NOTE(review): defined but not referenced anywhere in this file — presumably
# consumed by training/inference code elsewhere; confirm before removing.
BATCH_SIZE = 16
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ModelURI(StrEnum):
    """Hugging Face Hub identifiers of the ModernBERT checkpoints usable as backbones."""

    BASE = "answerdotai/ModernBERT-base"
    LARGE = "answerdotai/ModernBERT-large"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass(slots=True, frozen=True)
class LexicalExample:
    """Immutable (concept, definition) pair — one lexical sense example."""

    # Surface form / lemma of the target concept.
    concept: str
    # Gloss (dictionary definition) text for that concept.
    definition: str
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass(slots=True, frozen=True)
class PaddedBatch:
    """Immutable container for one padded tokenizer batch ready for the encoder."""

    # Token ids, padded to a common length within the batch.
    input_ids: torch.Tensor
    # 1 for real tokens, 0 for padding positions.
    attention_mask: torch.Tensor
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class DisamBertCrossEncoder(PreTrainedModel):
    """Cross-encoder for word-sense disambiguation on a ModernBERT backbone.

    Each input sequence encodes a (context, gloss) pair; the model emits one
    scalar logit per sequence: the dot product between the [CLS] hidden state
    and the hidden state of the sequence's first [SEP] token.
    """

    def __init__(self, config: PreTrainedConfig):
        super().__init__(config)
        if config.init_basemodel:
            # First construction from a base checkpoint: download pretrained
            # backbone weights from the Hub.
            self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
        else:
            # Reload path: build an empty backbone; weights arrive via the
            # normal from_pretrained state_dict loading.
            self.BaseModel = ModernBertModel(config)
        # Flip the flag so a re-saved config never re-downloads the backbone.
        config.init_basemodel = False
        # Binary relevance objective on the raw dot-product logit.
        # NOTE(review): BCEWithLogitsLoss expects float labels in [0, 1] —
        # confirm the training collator passes float, not long, labels.
        self.loss = nn.BCEWithLogitsLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI):
        """Alternate constructor: start a new model from a pretrained ModernBERT id."""
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        config.tokenizer_path = base_id
        return cls(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor | None = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> SequenceClassifierOutput:
        """Score each (context, gloss) sequence in the batch with one logit.

        Returns a SequenceClassifierOutput with logits of shape (batch,) and,
        when labels are given, the BCE-with-logits loss.
        """
        base_model_output = self.BaseModel(
            input_ids,
            attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        token_vectors = base_model_output.last_hidden_state
        # Collect the position of the FIRST sep_token_id in every row:
        # nonzero() yields (row, col) pairs in row-major order, so the first
        # hit for a new row index is that row's earliest [SEP].
        prev = -1
        rows = []
        cols = []
        for (i,j) in (input_ids == self.config.sep_token_id).nonzero():
            if i!=prev:
                rows.append(i)
                cols.append(j)
                prev=i
        # Hidden state of each row's first [SEP] — treated as the gloss vector.
        # NOTE(review): rows with no [SEP] are silently skipped, which would
        # misalign logits with the batch — assumes every sequence has one.
        gloss_vectors = token_vectors[rows,cols]
        # Per-row dot product between [CLS] (position 0) and the gloss vector.
        logits = torch.einsum("ij,ij->i",token_vectors[:,0],gloss_vectors)
        return SequenceClassifierOutput(
            logits=logits,
            loss=self.loss(logits, labels) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )
|
| 91 |
+
|
| 92 |
+
def get_lemma(text: str, synset: wordnet.synset) -> wordnet.lemma:
    """Return the lemma of *synset* whose name is closest to *text* by edit distance.

    Ties keep the earliest lemma in synset order (same as the original scan).
    Returns None when the synset has no lemmas.
    """
    # min() with a key replaces the manual best-score scan and its arbitrary
    # 1_000_000 sentinel; default=None preserves the empty-lemmas result.
    return min(
        synset.lemmas(),
        key=lambda lemma: edit_distance(text, lemma.name()),
        default=None,
    )
|
| 101 |
+
|
| 102 |
+
class CrossEncoderTagger:
    """Word-sense tagger that ranks WordNet senses with a DisamBertCrossEncoder checkpoint."""

    def __init__(self, url: str):
        # trust_remote_code: the model class (DisamBertCrossEncoder) ships with
        # the checkpoint via config.json's auto_map.
        self.model = AutoModel.from_pretrained(url,
                                               device_map="auto",
                                               trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(url)

    def __call__(self, target: str, sentence: str, candidates: list[str]) -> str:
        """Return the candidate synset name whose gloss scores highest for *target* in *sentence*.

        candidates are WordNet synset names, e.g. "dog.n.01".
        """
        text = f"{target}::{sentence}"
        synsets = [wordnet.synset(candidate) for candidate in candidates]
        # NOTE(review): this interpolates the Lemma object's repr (e.g.
        # "Lemma('dog.n.01.dog')"), not lemma.name() — confirm it matches the
        # format the model was trained on before changing it.
        definitions = [f"{get_lemma(target, synset)}::{synset.definition()}"
                       for synset in synsets]
        # Pair the same context sentence with every candidate gloss.
        sentences = [text] * len(candidates)
        # torch.device context: tensors created inside default to the model's device.
        with self.model.device:
            tokens = self.tokenizer(sentences, definitions, padding=True, return_tensors="pt")
            with torch.no_grad():  # inference only — skip autograd bookkeeping
                output = self.model(tokens.input_ids,
                                    tokens.attention_mask)
        logits = output.logits
        # 0-d argmax tensor supports __index__, so list indexing is valid.
        return candidates[logits.argmax()]
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
|
README.md
CHANGED
|
@@ -24,12 +24,12 @@ should probably proofread and complete it, then remove this comment. -->
|
|
| 24 |
|
| 25 |
This model is a fine-tuned version of [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) on the None dataset.
|
| 26 |
It achieves the following results on the evaluation set:
|
| 27 |
-
- Loss: 0.
|
| 28 |
-
- Precision: 0.
|
| 29 |
-
- Recall: 0.
|
| 30 |
-
- F1: 0.
|
| 31 |
-
- Accuracy: 0.
|
| 32 |
-
- Matthews Correlation: 0.
|
| 33 |
|
| 34 |
## Model description
|
| 35 |
|
|
@@ -60,19 +60,19 @@ The following hyperparameters were used during training:
|
|
| 60 |
|
| 61 |
### Training results
|
| 62 |
|
| 63 |
-
| Training Loss | Epoch | Step
|
| 64 |
-
|:-------------:|:-----:|:-----
|
| 65 |
-
| No log | 0 | 0
|
| 66 |
-
| 0.
|
| 67 |
-
| 0.
|
| 68 |
-
| 0.
|
| 69 |
-
| 0.
|
| 70 |
-
| 0.
|
| 71 |
-
| 0.
|
| 72 |
-
| 0.
|
| 73 |
-
| 0.
|
| 74 |
-
| 0.
|
| 75 |
-
| 0.
|
| 76 |
|
| 77 |
|
| 78 |
### Framework versions
|
|
|
|
| 24 |
|
| 25 |
This model is a fine-tuned version of [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) on the None dataset.
|
| 26 |
It achieves the following results on the evaluation set:
|
| 27 |
+
- Loss: 0.3160
|
| 28 |
+
- Precision: 0.6783
|
| 29 |
+
- Recall: 0.5978
|
| 30 |
+
- F1: 0.6355
|
| 31 |
+
- Accuracy: 0.9378
|
| 32 |
+
- Matthews Correlation: 0.6031
|
| 33 |
|
| 34 |
## Model description
|
| 35 |
|
|
|
|
| 60 |
|
| 61 |
### Training results
|
| 62 |
|
| 63 |
+
| Training Loss | Epoch | Step | Validation Loss | Precision | Recall | F1 | Accuracy | Matthews Correlation |
|
| 64 |
+
|:-------------:|:-----:|:-----:|:---------------:|:---------:|:------:|:------:|:--------:|:--------------------:|
|
| 65 |
+
| No log | 0 | 0 | 1123.2456 | 0.0907 | 1.0 | 0.1663 | 0.0909 | 0.0045 |
|
| 66 |
+
| 0.1943 | 1.0 | 9050 | 0.1832 | 0.7346 | 0.2615 | 0.3857 | 0.9245 | 0.4096 |
|
| 67 |
+
| 0.1500 | 2.0 | 18100 | 0.1551 | 0.7019 | 0.4967 | 0.5817 | 0.9352 | 0.5574 |
|
| 68 |
+
| 0.1242 | 3.0 | 27150 | 0.1481 | 0.7381 | 0.5451 | 0.6271 | 0.9412 | 0.6040 |
|
| 69 |
+
| 0.1017 | 4.0 | 36200 | 0.1482 | 0.7413 | 0.5604 | 0.6383 | 0.9424 | 0.6147 |
|
| 70 |
+
| 0.0774 | 5.0 | 45250 | 0.1564 | 0.7179 | 0.6154 | 0.6627 | 0.9432 | 0.6342 |
|
| 71 |
+
| 0.0610 | 6.0 | 54300 | 0.1859 | 0.7579 | 0.5297 | 0.6235 | 0.9420 | 0.6044 |
|
| 72 |
+
| 0.0434 | 7.0 | 63350 | 0.2016 | 0.6754 | 0.6264 | 0.6499 | 0.9388 | 0.6170 |
|
| 73 |
+
| 0.0298 | 8.0 | 72400 | 0.2480 | 0.6520 | 0.6505 | 0.6513 | 0.9368 | 0.6165 |
|
| 74 |
+
| 0.0216 | 9.0 | 81450 | 0.2961 | 0.6819 | 0.5890 | 0.6321 | 0.9378 | 0.6002 |
|
| 75 |
+
| 0.0174 | 10.0 | 90500 | 0.3160 | 0.6783 | 0.5978 | 0.6355 | 0.9378 | 0.6031 |
|
| 76 |
|
| 77 |
|
| 78 |
### Framework versions
|
config.json
CHANGED
|
@@ -4,6 +4,9 @@
|
|
| 4 |
],
|
| 5 |
"attention_bias": false,
|
| 6 |
"attention_dropout": 0.0,
|
|
|
|
|
|
|
|
|
|
| 7 |
"bos_token_id": null,
|
| 8 |
"classifier_activation": "gelu",
|
| 9 |
"classifier_bias": false,
|
|
|
|
| 4 |
],
|
| 5 |
"attention_bias": false,
|
| 6 |
"attention_dropout": 0.0,
|
| 7 |
+
"auto_map": {
|
| 8 |
+
"AutoModel": "DisamBertCrossEncoder.DisamBertCrossEncoder"
|
| 9 |
+
},
|
| 10 |
"bos_token_id": null,
|
| 11 |
"classifier_activation": "gelu",
|
| 12 |
"classifier_bias": false,
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 596071480
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2cb625d94dadd5a1929c852bb4728f74906eab0e8898e2300353ddaed125bb08
|
| 3 |
size 596071480
|
tokenizer.json
CHANGED
|
@@ -1,7 +1,14 @@
|
|
| 1 |
{
|
| 2 |
"version": "1.0",
|
| 3 |
"truncation": null,
|
| 4 |
-
"padding":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"added_tokens": [
|
| 6 |
{
|
| 7 |
"id": 0,
|
|
|
|
| 1 |
{
|
| 2 |
"version": "1.0",
|
| 3 |
"truncation": null,
|
| 4 |
+
"padding": {
|
| 5 |
+
"strategy": "BatchLongest",
|
| 6 |
+
"direction": "Right",
|
| 7 |
+
"pad_to_multiple_of": null,
|
| 8 |
+
"pad_id": 50283,
|
| 9 |
+
"pad_type_id": 0,
|
| 10 |
+
"pad_token": "[PAD]"
|
| 11 |
+
},
|
| 12 |
"added_tokens": [
|
| 13 |
{
|
| 14 |
"id": 0,
|