import torch
from torch import nn
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
# Custom Trainer applying class-weighted cross-entropy
# (weights assume a 3-class problem; tune them to your label distribution)
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments newer transformers versions pass here
        # Pop labels so the forward pass does not also compute the default loss
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # Put the class weights on the same device as the logits
        loss_fct = nn.CrossEntropyLoss(
            weight=torch.tensor([1.0, 2.0, 3.0], device=logits.device)
        )
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
# Load pre-trained model with a 3-way classification head
# (num_labels must match the three class weights used in CustomTrainer)
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
# Load tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Tokenize your dataset. In practice, train_texts and train_labels come from
# your own data; the toy values below are only so the script runs end to end.
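# Toy 3-class data (purely illustrative -- replace with a real dataset)
train_texts = ["great product, loved it", "it was okay, nothing special", "terrible experience"]
train_labels = [0, 1, 2]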
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
# Keep labels as a plain list; the Dataset below converts each item to a tensor
# Define a PyTorch Dataset from the encodings and the labels
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)
# Create a Dataset object
train_dataset = MyDataset(train_encodings, train_labels)
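# Illustrative sanity check: each item should expose tokenized inputs
# (input_ids, attention_mask, token_type_ids for BERT) plus a 'labels' tensor
print({k: v.shape for k, v in train_dataset[0].items()})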
# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
)
# Initialize the trainer
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)
# Train the model
trainer.train()
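# After training, you would typically persist the fine-tuned weights so they
# can be reloaded for inference. A minimal sketch (the path is illustrative):
trainer.save_model('./results/final_model')
tokenizer.save_pretrained('./results/final_model')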