import torch
from torch import nn
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments


class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer
        # transformers releases pass to compute_loss.
        # Pop the labels so the model's forward pass does not also compute its
        # own, unweighted loss.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # Class-weighted cross-entropy; the weight tensor must live on the same
        # device as the logits, or this fails on GPU.
        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=logits.device))
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
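
# Illustrative sanity check of what the class weights do (not part of the
# training pipeline): with reduction="sum", a sample of class 2 contributes
# exactly 3x the unweighted cross-entropy.
_logits = torch.tensor([[2.0, 0.5, 0.1]])
_target = torch.tensor([2])
_plain = nn.CrossEntropyLoss(reduction="sum")(_logits, _target)
_weighted = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]), reduction="sum")(_logits, _target)
assert torch.isclose(_weighted, 3.0 * _plain)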


# num_labels=3 so the classification head matches the three-entry weight vector
# used in CustomTrainer.compute_loss (the default head has only two labels).
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
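
# train_texts / train_labels are assumed to be defined elsewhere; a minimal
# illustrative stand-in so the listing runs end to end:
train_texts = ["great movie", "it was okay", "terrible acting"]
train_labels = [0, 1, 2]  # one integer label per text, in {0, 1, 2}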

train_encodings = tokenizer(train_texts, truncation=True, padding=True)
# Keep train_labels as a plain list: MyDataset.__getitem__ converts each label
# to a tensor, and re-wrapping an existing tensor would raise a UserWarning.


class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing the tokenizer output with the labels."""

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # One tensor per tokenizer field (input_ids, token_type_ids, attention_mask).
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


train_dataset = MyDataset(train_encodings, train_labels)
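
# Optional illustrative check: every item should be a dict of tensors with keys
# input_ids, token_type_ids, attention_mask, and labels.
print(train_dataset[0].keys())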

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
)

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

trainer.train()
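
# Afterwards, a minimal (illustrative) way to persist and query the fine-tuned
# model; "./results/final" is an arbitrary path chosen for this sketch.
trainer.save_model("./results/final")

model.eval()
inputs = tokenizer("an example sentence to classify", return_tensors="pt").to(model.device)
with torch.no_grad():
    predicted_class = model(**inputs).logits.argmax(dim=-1).item()
print(predicted_class)  # class index in {0, 1, 2}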