xhafaaldi committed on
Commit
11003f0
1 Parent(s): dc70156

Create chat.py

Files changed (1)
  1. chat.py +59 -0
chat.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ from torch import nn
+ from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
+
+ class CustomTrainer(Trainer):  # Trainer subclass that applies class-weighted cross-entropy
+     def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+         labels = inputs.get("labels")
+         outputs = model(**inputs)
+         logits = outputs.get("logits")
+         loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]).to(logits.device))
+         loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
+         return (loss, outputs) if return_outputs else loss
+
+ # Load pre-trained model (num_labels=3 to match the three class weights above)
+ model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
+
+ # Load tokenizer
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+
+ # Tokenize your dataset (you would need to define this yourself)
+ # This is a placeholder and will not run until train_texts and train_labels are defined
+ train_encodings = tokenizer(train_texts, truncation=True, padding=True)
+ train_labels = torch.tensor(train_labels)
+
+ # Define a PyTorch Dataset from the encodings and the labels
+ class MyDataset(torch.utils.data.Dataset):
+     def __init__(self, encodings, labels):
+         self.encodings = encodings
+         self.labels = labels
+
+     def __getitem__(self, idx):
+         item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
+         item['labels'] = self.labels[idx]
+         return item
+
+     def __len__(self):
+         return len(self.labels)
+
+ # Create a Dataset object
+ train_dataset = MyDataset(train_encodings, train_labels)
+
+ # Define training arguments
+ training_args = TrainingArguments(
+     output_dir='./results',
+     num_train_epochs=3,
+     per_device_train_batch_size=16,
+     warmup_steps=500,
+     weight_decay=0.01,
+ )
+
+ # Initialize the trainer
+ trainer = CustomTrainer(
+     model=model,
+     args=training_args,
+     train_dataset=train_dataset,
+ )
+
+ # Train the model
+ trainer.train()
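
Note: the committed script leaves train_texts and train_labels undefined ("you would need to define this yourself"), so chat.py raises a NameError at the tokenization step until they are supplied. Below is a minimal sketch of one way to provide them, assuming a plain list of strings and integer class ids in {0, 1, 2} to match the three class weights; the example texts and labels are purely illustrative, not part of the commit.

    # Hypothetical toy data; replace with your real corpus and labels.
    # Labels must be integers in {0, 1, 2} because the model is configured
    # with num_labels=3 and the loss uses three class weights.
    train_texts = [
        "the service was excellent",
        "delivery took longer than promised",
        "the device stopped working after a day",
    ]
    train_labels = [0, 1, 2]

Placed after the imports in chat.py, these definitions let the rest of the file run end to end; any text/label pairs with the same shape would work.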