LunaAi / training /train.py
faisaldadkhan13's picture
Upload 7 files
226aae3 verified
# training/train.py
import os
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, AdamW
from model.luna_model import LunaAI
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
class TextDataset(Dataset):
    """Dataset that tokenizes rows of a CSV file on demand.

    The first CSV column is read as the text and the second as the label
    (both are accessed by position, not by column name).
    """

    def __init__(self, csv_file, tokenizer, max_length=128):
        """Load ``csv_file`` with pandas and remember the tokenizer settings."""
        self.data = pd.read_csv(csv_file)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the loaded CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize row ``idx`` and return its tensors in a dict.

        Returns:
            A dict with ``'input_ids'`` and ``'attention_mask'`` (the
            tokenizer outputs flattened to 1-D) and ``'labels'`` (the row's
            label as a ``torch.long`` tensor).
        """
        raw_text = self.data.iloc[idx, 0]
        raw_label = self.data.iloc[idx, 1]

        # Pad/truncate every sample to max_length so samples share one size.
        encoded = self.tokenizer.encode_plus(
            raw_text,
            add_special_tokens=True,
            return_tensors='pt',
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
        )

        # return_tensors='pt' is requested above; the outputs are collapsed
        # to 1-D here before being handed back.
        token_ids = encoded['input_ids'].reshape(-1)
        mask = encoded['attention_mask'].reshape(-1)
        target = torch.tensor(raw_label, dtype=torch.long)

        return {
            'input_ids': token_ids,
            'attention_mask': mask,
            'labels': target,
        }
def evaluate_model(model, dataloader):
    """Compute classification metrics for ``model`` over ``dataloader``.

    The model is switched to evaluation mode for the duration of the call
    and then returned to whatever mode it was in before, so calling this in
    the middle of a training run does not leave the model stuck in eval
    mode for the following epochs.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``; its
            output is reduced with ``torch.max(..., dim=1)``, so class
            scores are expected along dimension 1.
        dataloader: Iterable of batches that are dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)``; precision, recall and
        F1 use ``average='weighted'``.
    """
    was_training = model.training

    # Inputs are moved to wherever the model's parameters live so that a
    # model placed on an accelerator can be evaluated with CPU batches.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    predictions, true_labels = [], []
    model.eval()
    try:
        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids']
                attention_mask = batch['attention_mask']
                if device is not None:
                    input_ids = input_ids.to(device)
                    attention_mask = attention_mask.to(device)
                outputs = model(input_ids, attention_mask)
                _, preds = torch.max(outputs, dim=1)
                predictions.extend(preds.cpu().numpy())
                true_labels.extend(batch['labels'].cpu().numpy())
    finally:
        # Restore the caller's train/eval mode even if evaluation fails.
        model.train(was_training)

    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predictions, average='weighted'
    )
    return accuracy, precision, recall, f1
def save_checkpoint(epoch, model, optimizer, loss, path="./checkpoints"):
    """Write a training checkpoint for ``epoch`` into directory ``path``.

    The directory is created if it does not exist. The file is named
    ``checkpoint_epoch_<epoch>.pth`` and stores the epoch number, the model
    and optimizer state dicts, and the given loss value.
    """
    os.makedirs(path, exist_ok=True)

    target_file = os.path.join(path, f"checkpoint_epoch_{epoch}.pth")
    payload = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(payload, target_file)
def train_model(model, dataset, epochs=3, batch_size=16, learning_rate=5e-5):
    """Train ``model`` on ``dataset`` with AdamW and cross-entropy loss.

    After every epoch the loss of the last batch is printed, a checkpoint
    is written with ``save_checkpoint`` and metrics are computed with
    ``evaluate_model``. The metrics are computed on the same dataloader
    that is used for training, so they describe fit on the training data,
    not held-out performance.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` whose
            output is passed to ``nn.CrossEntropyLoss`` together with the
            batch labels.
        dataset: Dataset yielding dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'``.
        epochs: Number of passes over the dataset.
        batch_size: Batch size given to the ``DataLoader``.
        learning_rate: Learning rate given to the optimizer.
    """
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    # Built once; the loss module holds no per-batch state.
    criterion = nn.CrossEntropyLoss()

    # Batches are moved to the device holding the model's parameters so a
    # model that the caller placed on an accelerator works unchanged.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    for epoch in range(epochs):
        # Re-enter training mode every epoch: the end-of-epoch evaluation
        # switches the model to eval mode, which would otherwise persist
        # into all later epochs.
        model.train()
        for batch in dataloader:
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            labels = batch['labels']
            if device is not None:
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)
                labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Loss of the final batch of this epoch (not an epoch average).
        last_loss = loss.item()
        print(f'Epoch {epoch}, Loss: {last_loss}')
        save_checkpoint(epoch, model, optimizer, last_loss)

        # Optional: Evaluate the model at each epoch end
        accuracy, precision, recall, f1 = evaluate_model(model, dataloader)
        print(f'Epoch {epoch} - Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}, F1 Score: {f1}')
if __name__ == "__main__":
    # Script entry point: build the tokenizer, the CSV-backed dataset and
    # the classifier, then run training with the default hyperparameters.
    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    training_data = TextDataset('data/dataset.csv', bert_tokenizer)
    # Adjust num_classes if necessary
    classifier = LunaAI(num_classes=2)
    train_model(classifier, training_data)