Clemet's picture
Upload utils.py
e79483d verified
raw
history blame
No virus
1.33 kB
from safetensors.torch import load_model
from transformers import RobertaTokenizer, AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from transformers import GPT2TokenizerFast, GPT2ForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
def get_roberta(weights_path="roberta.safetensors"):
    """Build the RoBERTa sequence classifier and its tokenizer.

    The model is instantiated from the ``cross-encoder/nli-roberta-base``
    checkpoint, after which its parameters are loaded from a local
    safetensors file.

    Args:
        weights_path: Path of the safetensors file to load into the model.
            Defaults to ``"roberta.safetensors"``; a relative path is
            resolved against the current working directory.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = 'cross-encoder/nli-roberta-base'
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    # Replace the hub weights with the ones stored in the local file.
    load_model(model, weights_path)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return tokenizer, model
def get_gpt(weights_path="gpt.safetensors"):
    """Build the 3-label GPT-2 sequence classifier and its tokenizer.

    The model is instantiated from the ``gpt2`` checkpoint with a 3-way
    classification head, after which its parameters are loaded from a local
    safetensors file.

    Args:
        weights_path: Path of the safetensors file to load into the model.
            Defaults to ``"gpt.safetensors"``; a relative path is resolved
            against the current working directory.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    model = GPT2ForSequenceClassification.from_pretrained('gpt2', num_labels=3)
    # The end-of-sequence token doubles as the padding token, on both the
    # model config and the tokenizer, so the two agree on padded inputs.
    model.config.pad_token_id = model.config.eos_token_id
    load_model(model, weights_path)
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer, model
def get_distilbert(weights_path="distilbert.safetensors"):
    """Build the 3-label DistilBERT sequence classifier and its tokenizer.

    The model is instantiated from the ``distilbert-base-uncased`` checkpoint
    with a 3-way classification head, after which its parameters are loaded
    from a local safetensors file.

    Args:
        weights_path: Path of the safetensors file to load into the model.
            Defaults to ``"distilbert.safetensors"``; a relative path is
            resolved against the current working directory.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = 'distilbert-base-uncased'
    model = DistilBertForSequenceClassification.from_pretrained(checkpoint, num_labels=3)
    # Replace the hub weights with the ones stored in the local file.
    load_model(model, weights_path)
    tokenizer = DistilBertTokenizerFast.from_pretrained(checkpoint)
    return tokenizer, model