Clemet's picture
Upload 9 files
049e137 verified
raw
history blame contribute delete
No virus
1.58 kB
from safetensors.torch import load_model
from transformers import RobertaTokenizer, AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from transformers import GPT2TokenizerFast, GPT2ForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
import numpy as np
def get_roberta():
    """Build the RoBERTa sequence classifier together with its tokenizer.

    The model is instantiated from the 'cross-encoder/nli-roberta-base'
    checkpoint with attention outputs enabled, after which the weights stored
    in "roberta.safetensors" are loaded into it.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = 'cross-encoder/nli-roberta-base'
    classifier = AutoModelForSequenceClassification.from_pretrained(
        checkpoint,
        output_attentions=True,
    )
    # Replace the checkpoint weights with the ones saved in the local file.
    load_model(classifier, "roberta.safetensors")
    tok = AutoTokenizer.from_pretrained(checkpoint)
    return tok, classifier
def get_gpt():
    """Build the GPT-2 sequence classifier together with its tokenizer.

    The model is instantiated from the 'gpt2' checkpoint with three labels and
    attention outputs enabled, then the weights stored in "gpt.safetensors"
    are loaded into it. Both the model config and the tokenizer are set to
    reuse the end-of-sequence token as the padding token.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = 'gpt2'
    classifier = GPT2ForSequenceClassification.from_pretrained(
        checkpoint,
        num_labels=3,
        output_attentions=True,
    )
    # Use the EOS token id as the padding id on the model side.
    config = classifier.config
    config.pad_token_id = config.eos_token_id
    load_model(classifier, "gpt.safetensors")

    tok = GPT2TokenizerFast.from_pretrained(checkpoint)
    # Mirror the model-side choice: pad with the EOS token.
    tok.pad_token = tok.eos_token
    return tok, classifier
def get_distilbert():
    """Build the DistilBERT sequence classifier together with its tokenizer.

    The model is instantiated from the 'distilbert-base-uncased' checkpoint
    with three labels and attention outputs enabled, then the weights stored
    in "distilbert.safetensors" are loaded into it.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = 'distilbert-base-uncased'
    classifier = DistilBertForSequenceClassification.from_pretrained(
        checkpoint,
        num_labels=3,
        output_attentions=True,
    )
    # Replace the checkpoint weights with the ones saved in the local file.
    load_model(classifier, "distilbert.safetensors")
    tok = DistilBertTokenizerFast.from_pretrained(checkpoint)
    return tok, classifier
def softmax(xx):
    """Compute softmax values for the first set of scores in ``xx``.

    Args:
        xx: A tensor-like object exposing ``detach()``, ``cpu()`` and
            ``numpy()`` (e.g. a ``torch.Tensor``). Only its first entry
            along the leading dimension (``xx[0]``) is used.

    Returns:
        A NumPy array with the softmax of ``xx[0]`` taken along axis 0.
    """
    # .cpu() makes the conversion work for tensors living on an accelerator;
    # it is a no-op for tensors that are already on the CPU.
    scores = xx.detach().cpu().numpy()[0]
    if scores.size == 0:
        # Nothing to normalise; np.max would raise on an empty array.
        return np.exp(scores)
    # Subtracting the maximum leaves the result mathematically unchanged but
    # keeps np.exp from overflowing to inf (which would yield nan) on large
    # scores.
    shifted = scores - np.max(scores, axis=0)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=0)