|
|
|
import os |
|
import pandas as pd |
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import DataLoader, Dataset |
|
from transformers import BertTokenizer, AdamW |
|
from model.luna_model import LunaAI |
|
from sklearn.metrics import accuracy_score, precision_recall_fscore_support |
|
|
|
class TextDataset(Dataset):
    """Dataset of (text, label) rows read from a CSV file, tokenized on access.

    Columns are addressed by position: the first column is the text and the
    second column is the class label. Header names are not consulted.
    """

    def __init__(self, csv_file, tokenizer, max_length=128):
        """Load the CSV and store the tokenization settings.

        Args:
            csv_file: Path (or any source ``pandas.read_csv`` accepts) of the
                CSV file to load.
            tokenizer: Object exposing ``encode_plus`` with the keyword
                arguments used in ``__getitem__`` (the script entry point
                passes a ``BertTokenizer``).
            max_length: Length every encoded sequence is padded or truncated
                to.
        """
        self.data = pd.read_csv(csv_file)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the loaded CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize row ``idx`` and return its tensors.

        Returns:
            A dict with ``'input_ids'`` and ``'attention_mask'`` (the
            tokenizer outputs flattened to 1-D) and ``'labels'`` (a scalar
            ``torch.long`` tensor).
        """
        text = self.data.iloc[idx, 0]
        label = self.data.iloc[idx, 1]

        # pandas yields float NaN for empty cells and numeric scalars for
        # number-only columns; the tokenizer needs a real string, so coerce
        # here and treat a missing cell as empty text.
        text = "" if pd.isna(text) else str(text)

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            return_tensors='pt',
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
        )
        return {
            # return_tensors='pt' adds a leading batch axis; flatten drops it
            # so the DataLoader can stack items into a batch itself.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long),
        }
|
|
|
def evaluate_model(model, dataloader):
    """Run ``model`` over ``dataloader`` and compute classification metrics.

    Args:
        model: Callable module invoked as ``model(input_ids, attention_mask)``.
            Its output is reduced with argmax over ``dim=1``, so it is
            presumably a ``(batch, num_classes)`` score tensor - confirm
            against the model definition.
        dataloader: Iterable of batches, each a mapping with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)``; precision, recall and F1
        use ``average='weighted'``.
    """
    # Remember the caller's mode: evaluation must not leave a model that is
    # in the middle of training stuck in eval mode.
    was_training = model.training
    model.eval()
    predictions, true_labels = [], []
    try:
        with torch.no_grad():
            for batch in dataloader:
                outputs = model(batch['input_ids'], batch['attention_mask'])
                preds = torch.argmax(outputs, dim=1)
                predictions.extend(preds.cpu().numpy())
                true_labels.extend(batch['labels'].cpu().numpy())
    finally:
        # Restore the previous mode even if the forward pass raised.
        model.train(was_training)

    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predictions, average='weighted'
    )
    return accuracy, precision, recall, f1
|
|
|
def save_checkpoint(epoch, model, optimizer, loss, path="./checkpoints"):
    """Write a training checkpoint for ``epoch`` into the directory ``path``.

    The directory is created if it does not exist. The file is named
    ``checkpoint_epoch_<epoch>.pth`` and stores the epoch, the model and
    optimizer state dicts, and ``loss`` as passed in.
    """
    os.makedirs(path, exist_ok=True)
    checkpoint_file = os.path.join(path, f"checkpoint_epoch_{epoch}.pth")
    snapshot = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    )
    torch.save(snapshot, checkpoint_file)
|
|
|
def train_model(model, dataset, epochs=3, batch_size=16, learning_rate=5e-5):
    """Train ``model`` on ``dataset``, checkpointing and evaluating each epoch.

    Args:
        model: Module invoked as ``model(input_ids, attention_mask)`` whose
            output is fed to ``nn.CrossEntropyLoss`` together with the labels.
        dataset: Dataset whose items are mappings with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'``.
        epochs: Number of passes over the dataset.
        batch_size: Batch size for the shuffled ``DataLoader``.
        learning_rate: Learning rate passed to ``AdamW``.

    Side effects:
        Prints the loss and the evaluation metrics, and writes one checkpoint
        per epoch through ``save_checkpoint`` (into its default directory).
    """
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    # The loss module is stateless, so build it once instead of once per batch.
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        # The per-epoch evaluation below calls model.eval(); re-assert
        # training mode at the start of every epoch so later epochs do not
        # train in eval mode.
        model.train()

        for batch in dataloader:
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            labels = batch['labels']

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # ``loss`` is the loss of the last batch of this epoch.
        print(f'Epoch {epoch}, Loss: {loss.item()}')

        save_checkpoint(epoch, model, optimizer, loss.item())

        # NOTE: metrics are computed on the training dataloader itself, so
        # they measure fit to the training data, not generalization.
        accuracy, precision, recall, f1 = evaluate_model(model, dataloader)
        print(f'Epoch {epoch} - Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}, F1 Score: {f1}')
|
|
|
if __name__ == "__main__":
    # Script entry point: build the tokenizer, the dataset and the model,
    # then run training with train_model's default hyperparameters.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    dataset = TextDataset(csv_file='data/dataset.csv', tokenizer=tokenizer)
    model = LunaAI(num_classes=2)
    train_model(model=model, dataset=dataset)
|
|