import glob
import os

import torch
from transformers import BertTokenizerFast as BertTokenizer, BertForSequenceClassification

LABEL_COLUMNS = ["Assertive Tone", "Conversational Tone", "Emotional Tone", "Informative Tone", "None"]

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=5)
id2label = {i: label for i, label in enumerate(LABEL_COLUMNS)}
label2id = {label: i for i, label in enumerate(LABEL_COLUMNS)}

for ckpt in glob.glob('checkpoints/*.ckpt'):
    base_name = os.path.basename(ckpt)
    model_name = os.path.splitext(base_name)[0]

    # Lightning checkpoints keep the weights under the 'state_dict' key;
    # strict=True assumes those keys match BertForSequenceClassification exactly.
    params = torch.load(ckpt, map_location="cpu")['state_dict']
    msg = model.load_state_dict(params, strict=True)

    path = f'models/{model_name}'
    os.makedirs(path, exist_ok=True)

    # Write weights, config (with label mappings) and vocab in the standard
    # Hugging Face layout so the directory can be loaded with from_pretrained().
    torch.save(model.state_dict(), f'{path}/pytorch_model.bin')

    config = model.config
    config.architectures = ['BertForSequenceClassification']
    config.label2id = label2id
    config.id2label = id2label
    config.to_json_file(f'{path}/config.json')

    tokenizer.save_vocabulary(path)
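

# Optional sanity check: a minimal sketch, not part of the export loop above.
# It reloads one exported directory through the standard from_pretrained API to
# confirm that config.json, pytorch_model.bin and vocab.txt are all in place.
# The 'models/*' glob and the sample sentence are assumptions for illustration.
exported = glob.glob('models/*')
if exported:
    reloaded_model = BertForSequenceClassification.from_pretrained(exported[0])
    reloaded_tokenizer = BertTokenizer.from_pretrained(exported[0])
    inputs = reloaded_tokenizer("This is a quick smoke test.", return_tensors="pt")
    with torch.no_grad():
        logits = reloaded_model(**inputs).logits
    print(exported[0], logits.shape)  # expected: torch.Size([1, 5])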