import matplotlib.pyplot as plt
import numpy as np
import time
import tracemalloc

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from nltk.translate.bleu_score import corpus_bleu
from rouge import Rouge
from sklearn.metrics import f1_score

from Andromeda.model import Andromeda
from Andromeda.utils.stable_adamw import StableAdamWUnfused

# Fix the random seeds so evaluation runs are reproducible.
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class AccuracyMetrics:
    """Collects the accuracy-oriented metrics: perplexity, BLEU, ROUGE, and F1."""

    def __init__(self):
        self.rouge = Rouge()

    def calculate_perplexity(self, model, data_loader):
        """Exponentiated mean cross-entropy over all batches in the loader."""
        model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for batch in data_loader:
                input_ids, labels = batch
                input_ids, labels = input_ids.to(device), labels.to(device)
                output = model(input_ids)
                loss = F.cross_entropy(
                    output.view(-1, output.size(-1)), labels.view(-1)
                )
                total_loss += loss.item()
        return torch.exp(torch.tensor(total_loss / len(data_loader)))

    def calculate_bleu(self, references, hypotheses):
        # corpus_bleu expects tokenised input: a list of reference lists (each
        # reference a list of tokens) and a list of hypothesis token lists.
        return corpus_bleu(references, hypotheses)

    def calculate_rouge(self, references, hypotheses):
        # The rouge package works on raw strings (or lists of strings),
        # not token lists.
        scores = self.rouge.get_scores(hypotheses, references, avg=True)
        return scores

    def calculate_f1(self, true_labels, pred_labels):
        return f1_score(true_labels, pred_labels, average="weighted")


# FakeData is an image dataset and only acts as a stand-in here; a real
# evaluation needs a tokenised text dataset yielding (input_ids, labels) pairs.
test_dataset = datasets.FakeData(size=1000, transform=transforms.ToTensor())

model = Andromeda(
    num_tokens=50304,
    dim=1024,
    depth=24,
    dim_head=128,
    heads=8,
    alibi_num_heads=4,
).to(device)
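
# A freshly initialised model gives meaningless metrics; in practice a trained
# checkpoint would be loaded first (the path below is hypothetical):
# model.load_state_dict(torch.load("andromeda_checkpoint.pt", map_location=device))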


if __name__ == "__main__":
    accuracy_metrics = AccuracyMetrics()
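
    # `data_loader` must yield (input_ids, labels) token batches for the
    # perplexity computation; the FakeData images above are not usable for a
    # language model. As a minimal stand-in, assuming the model maps token ids
    # to per-token logits, score random sequences against their shifted-by-one
    # targets; a real run would tokenise held-out text instead.
    seq_len = 128
    vocab_size = 50304  # matches num_tokens above
    tokens = torch.randint(0, vocab_size, (64, seq_len + 1))
    token_dataset = torch.utils.data.TensorDataset(tokens[:, :-1], tokens[:, 1:])
    data_loader = DataLoader(token_dataset, batch_size=8, shuffle=False)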
    perplexity = accuracy_metrics.calculate_perplexity(model, data_loader)
    print("Perplexity:", perplexity)
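
    # Hypothetical reference/hypothesis pairs so the BLEU and ROUGE calls below
    # have something to score; in practice these come from decoding the model on
    # a held-out set. BLEU expects token lists, ROUGE expects raw strings.
    references = [
        [["the", "cat", "sat", "on", "the", "mat"]],
        [["there", "is", "a", "dog", "in", "the", "garden"]],
    ]
    hypotheses = [
        ["the", "cat", "sat", "on", "the", "mat"],
        ["a", "dog", "is", "in", "the", "garden"],
    ]
    references_text = [" ".join(refs[0]) for refs in references]
    hypotheses_text = [" ".join(hyp) for hyp in hypotheses]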
    bleu = accuracy_metrics.calculate_bleu(references, hypotheses)
    print("BLEU Score:", bleu)

    # ROUGE is computed on the string versions, since the rouge package does
    # not accept token lists.
    rouge_scores = accuracy_metrics.calculate_rouge(references_text, hypotheses_text)
    print("ROUGE Scores:", rouge_scores)
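
    # Hypothetical classification labels for the F1 computation; a real
    # evaluation would compare gold labels against model predictions.
    true_labels = [0, 1, 1, 0, 1, 0]
    pred_labels = [0, 1, 0, 0, 1, 1]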
    f1 = accuracy_metrics.calculate_f1(true_labels, pred_labels)
    print("F1 Score:", f1)