from model import DocBERT
from dataset import load_data, create_data_loaders

import argparse
import sklearn.metrics
import numpy as np
import torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Document Classification with Distillation")
    parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset")
    parser.add_argument("--bert_model", type=str, default="bert-base-uncased", help="Pre-trained BERT model name")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model checkpoint")
    parser.add_argument("--max_seq_length", type=int, default=250, help="Maximum sequence length for BERT (e.g., 250 for PhoBERT, whose max_position_embeddings is 258)")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for evaluation")
    parser.add_argument("--num_classes", type=int, required=True, help="Number of classes for classification")
    parser.add_argument("--text_column", type=str, default="text", help="Column name for text data")
    parser.add_argument("--label_column", type=str, nargs="+", help="Column name(s) for labels; pass several names for multi-category classification")
    parser.add_argument("--class_names", type=str, nargs="+", required=False, help="List of class names for classification")
    parser.add_argument("--inference_batch_limit", type=int, default=-1, help="Stop inference after this many batches (-1 means no limit)")
    parser.add_argument("--print_predictions", action="store_true", help="Print per-example predictions to the console")
    parser.add_argument("--threshold", type=float, default=0.55, help="Probability threshold for multi-category classification")
    args = parser.parse_args()
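
    # Example invocation (the script name, paths, and label values here are
    # illustrative only):
    #   python inference.py --data_path data/test.csv --model_path checkpoints/docbert.pt \
    #       --num_classes 4 --label_column label --class_names World Sports Business Tech \
    #       --print_predictions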

    # --class_names is optional; fall back to an empty list so label ids with no
    # configured name are reported as "Unknown" below instead of raising.
    class_names = args.class_names or []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # A single --label_column collapses to a plain string (single-task setup);
    # several columns keep the list and enable the multi-category branch below.
    label_column = args.label_column[0] if isinstance(args.label_column, list) and len(args.label_column) == 1 else args.label_column
    num_categories = len(args.label_column) if isinstance(args.label_column, list) else 1

    # Route the entire dataset into the test split: this script only evaluates.
    train_data, val_data, test_data = load_data(
        args.data_path,
        text_col=args.text_column,
        label_col=label_column,
        validation_split=0.0,
        test_split=1.0,
    )
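
    # With these splits, the train/val loaders should come back empty; only
    # test_loader is consumed below.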
    train_loader, val_loader, test_loader = create_data_loaders(
        train_data=train_data,
        val_data=val_data,
        test_data=test_data,
        tokenizer_name=args.bert_model,
        batch_size=args.batch_size,
        max_length=args.max_seq_length,
        num_classes=args.num_classes,
    )
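
    # Rebuild the DocBERT architecture, then load the fine-tuned weights and
    # move the model to the target device.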
    model = DocBERT(bert_model_name=args.bert_model, num_classes=args.num_classes, num_categories=num_categories)
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model = model.to(device)
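    # Eval mode disables dropout so inference is deterministic.
    model.eval()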

    all_labels = np.array([], dtype=int)
    all_predictions = np.array([], dtype=int)
    batch_window_index = 0
    batch_size = args.batch_size
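
    # Run the test set through the model batch by batch, accumulating true and
    # predicted labels for the metrics computed after the loop.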
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # token_type_ids is provided by the loader but not consumed by this
        # model's forward pass.
        token_type_ids = batch['token_type_ids'].to(device)
        labels = batch['label'].to(device)
        all_labels = np.append(all_labels, labels.cpu().numpy())

        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs
            if num_categories > 1:
                # Use a throwaway name for the first dimension so batch_size
                # (used for indexing into test_data below) is not shadowed.
                _, total_classes = outputs.shape
                if total_classes % num_categories != 0:
                    raise ValueError(f"Number of total classes in the batch must be divisible by {num_categories}")

                classes_per_group = total_classes // num_categories

                # Split the flat logits into one score group per category and
                # softmax within each group, i.e. over the last dimension.
                reshaped = outputs.view(outputs.size(0), -1, classes_per_group)
                probs = torch.softmax(reshaped, dim=-1)
                # Zero out probabilities at or below the threshold; a group with
                # no confident class then falls back to index 0.
                probs = torch.where(probs > args.threshold, probs, torch.zeros_like(probs))

                predictions = probs.argmax(dim=-1)
            else:
                predictions = torch.argmax(logits, dim=-1)

        all_predictions = np.append(all_predictions, predictions.cpu().numpy())

        if args.print_predictions:
            for idx in range(len(predictions)):
                pred_id = predictions[idx].item()
                true_id = labels[idx].item()
                print(f"Text: {test_data[0][batch_window_index * batch_size + idx]}")
                print(f"True Label: {true_id}, Predicted Label: {pred_id}")
                print(f"Predicted Class: {class_names[pred_id] if len(class_names) > pred_id else 'Unknown'}")
                print(f"True Class: {class_names[true_id] if len(class_names) > true_id else 'Unknown'}")
                print("-" * 50)

        batch_window_index += 1
        if args.inference_batch_limit > 0 and batch_window_index >= args.inference_batch_limit:
            break
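
    # 'weighted' averaging weights each class's score by its support, a common
    # choice when the label distribution is imbalanced.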
    accuracy = sklearn.metrics.accuracy_score(all_labels, all_predictions)
    f1 = sklearn.metrics.f1_score(all_labels, all_predictions, average='weighted')
    precision = sklearn.metrics.precision_score(all_labels, all_predictions, average='weighted')
    recall = sklearn.metrics.recall_score(all_labels, all_predictions, average='weighted')

    print(f"Accuracy: {accuracy}")
    print(f"F1 Score: {f1}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")

    # Write per-example results; indexing test_data[0] assumes the test loader
    # preserves dataset order and inference was not cut short by
    # --inference_batch_limit.
    with open("predictions.txt", "w") as f:
        for idx in range(len(all_labels)):
            f.write(f"Text: {test_data[0][idx]}\n")
            f.write(f"True Label: {all_labels[idx]}, Predicted Label: {all_predictions[idx]}\n")
            f.write("-" * 50 + "\n")

    with open("metrics.txt", "w") as f:
        f.write(f"Accuracy: {accuracy}\n")
        f.write(f"F1 Score: {f1}\n")
        f.write(f"Precision: {precision}\n")
        f.write(f"Recall: {recall}\n")