# NOTE(review): the three lines below were scraped HF Spaces page-status text
# ("Spaces:" / "Runtime error"), not part of the script; kept as a comment.
# This is a heavily adapted version of this notebook:
# https://github.com/huggingface/notebooks/blob/main/examples/text_classification.ipynb ,
# where we show on a simple text classification problem how we can integrate
# components for uncertainty quantification into large pretrained models.
# Third-party tooling: metrics, datasets, and model/training utilities.
import evaluate
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)

# Project-local BERT variant with an uncertainty-quantification head.
from uq import BertForUQSequenceClassification

# Training configuration.
BATCH_SIZE = 16
EVAL_BATCH_SIZE = 128
DEVICE = "cpu"

# CoLA task: decide whether a sentence is grammatically acceptable.
task = "cola"
model_checkpoint = "bert-base-uncased"

dataset = load_dataset("glue", task)
metric = evaluate.load("glue", task)

# Fast tokenizer matching the checkpoint; applied to the dataset as it streams in.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)


def tokenize_data(data):
    """Tokenize a batch of examples, adding input-ID and attention-mask columns."""
    return tokenizer(data["sentence"], truncation=True)


encoded_dataset = dataset.map(tokenize_data, batched=True)
# Load the pretrained backbone and introduce the uncertainty-quantification
# component — in this case a GP final layer, without any spectral
# normalization of the transformer weights.
num_labels = 2
id2label = {0: "Invalid", 1: "Valid"}
# Inverse mapping, derived from id2label so the two can never drift apart.
label2id = {label: idx for idx, label in id2label.items()}

model = BertForUQSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)
# Training arguments: checkpoint each epoch and keep the model that scores
# best on Matthews correlation, the standard metric for CoLA.
metric_name = "matthews_correlation"
model_name = model_checkpoint.split("/")[-1]

args = TrainingArguments(
    f"{model_name}-finetuned-{task}",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    # NOTE(review): pushing to the Hub requires an authenticated HF login;
    # a missing token would make this raise at runtime — confirm credentials.
    push_to_hub=True,
    # Force CPU training, matching DEVICE above.
    use_mps_device=False,
    no_cuda=True,
)
# Metric tracking hook for the Trainer.
def compute_metrics(eval_predictions):
    """Turn raw logits into class predictions and score them with the GLUE metric."""
    logits, labels = eval_predictions
    predicted_classes = np.argmax(logits, axis=1)
    return metric.compute(predictions=predicted_classes, references=labels)
# Finally, build the Trainer that fine-tunes the model on CPU.
model.to(DEVICE)

trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
# Reset the covariance matrix after each epoch: only the final epoch's
# covariance is needed, and resetting avoids double counting any of the data.
# A more elegant one-shot solution is possible, but the covariance computation
# is cheap, so recomputing it ~5 times instead of once is not a problem.
class ResetCovarianceCallback(TrainerCallback):
    """Trainer callback that clears the UQ head's covariance at each epoch end."""

    def __init__(self, trainer) -> None:
        super().__init__()
        # Keep a handle on the trainer so we can reach the live model later.
        self._trainer = trainer

    def on_epoch_end(self, args, state, control, **kwargs):
        # Only reset when this epoch boundary also triggers an evaluation.
        if control.should_evaluate:
            self._trainer.model.classifier.reset_cov()
# Register the covariance-reset callback, fine-tune, and publish to the Hub.
callback = ResetCovarianceCallback(trainer)
trainer.add_callback(callback)
trainer.train()
trainer.push_to_hub()