| """BERT for sentence-pair boundary classification.""" |
|
|
| from transformers import ( |
| AutoTokenizer, |
| BertForSequenceClassification, |
| PreTrainedTokenizerFast, |
| ) |
|
|
| from src.datasets.combined_pairs_dataset import NUM_LABELS, ID2LABEL, LABEL2ID |
|
|
# Default checkpoint identifier; used as the default ``pretrained`` argument
# of both ``load_bert`` and ``load_bert_tokenizer`` in this module.
BASE_MODEL = "bert-base-uncased"
|
|
|
|
def load_bert(
    pretrained: str = BASE_MODEL,
) -> BertForSequenceClassification:
    """Load a BERT sequence-classification model wired to the dataset labels.

    Args:
        pretrained: Checkpoint identifier handed to
            ``BertForSequenceClassification.from_pretrained``. Defaults to
            ``BASE_MODEL``.

    Returns:
        The model produced by ``from_pretrained``, requested with
        ``num_labels=NUM_LABELS``, ``id2label=ID2LABEL`` and
        ``label2id=LABEL2ID`` as imported from the combined-pairs dataset
        module.
    """
    # Keep the label-related settings together so the classifier head and the
    # id <-> label mappings are always passed as one unit.
    label_settings = {
        "num_labels": NUM_LABELS,
        "id2label": ID2LABEL,
        "label2id": LABEL2ID,
    }
    model = BertForSequenceClassification.from_pretrained(
        pretrained,
        **label_settings,
    )
    return model
|
|
|
|
def load_bert_tokenizer(
    pretrained: str = BASE_MODEL,
) -> PreTrainedTokenizerFast:
    """Load the tokenizer that belongs to a given checkpoint.

    Args:
        pretrained: Checkpoint identifier handed to
            ``AutoTokenizer.from_pretrained``. Defaults to ``BASE_MODEL``.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained`` when
        called with ``use_fast=True``.
    """
    # The fast implementation is requested explicitly via ``use_fast=True``.
    tokenizer = AutoTokenizer.from_pretrained(pretrained, use_fast=True)
    return tokenizer
|
|