import time
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
import torch
import gzip
import json
import numpy as np
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
import argparse
from torch import nn
class TextClassificationModel(nn.Module):
    """Bag-of-embeddings text classifier: ``nn.EmbeddingBag`` followed by ``nn.Linear``.

    Besides the network itself, the class bundles the training/evaluation
    loops, checkpoint save/load, and the static data helpers used by the
    script at the bottom of this file.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Width of each embedding vector.
        num_class: Number of output logits.
        vocab: Callable applied to a list of tokens; its result is fed to
            ``torch.tensor`` in ``predict`` (presumably a torchtext ``Vocab``
            returning token indices -- confirm against the caller).
    """

    def __init__(self, vocab_size, embed_dim, num_class, vocab):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.vocab_size = vocab_size
        self.emsize = embed_dim
        self.num_class = num_class
        self.vocab = vocab
        self.tokenizer_convert = get_tokenizer("basic_english")
        self.text_pipeline = self.tokenizer
        # init_weights was defined but never invoked; apply it so the declared
        # initialisation range is actually used.
        self.init_weights()

    @staticmethod
    def _module_default(name, required=True):
        """Return the module-level setting *name* defined by the script section.

        Raises:
            NameError: If *required* is true and no such module global exists.
        """
        try:
            return globals()[name]
        except KeyError:
            if required:
                raise NameError(
                    "{!r} was not passed and is not defined at module level".format(name)
                ) from None
            return None

    def tokenizer(self, text):
        """Tokenize *text* and pass the tokens through ``self.vocab``."""
        return self.vocab(self.tokenizer_convert(text))

    def init_weights(self):
        """Re-initialise embedding and linear weights uniformly in [-0.5, 0.5]."""
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)

    def forward(self, text, offsets):
        """Return class logits for the bags described by *text* and *offsets*."""
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

    def train_model(self, train_dataloader, valid_dataloader, epochs=None,
                    criterion=None, optimizer=None, scheduler=None):
        """Train the model, validating after every epoch.

        Args:
            train_dataloader: Iterable of ``(label, text, offsets)`` batches;
                must support ``len()`` for progress logging.
            valid_dataloader: Iterable of ``(label, text, offsets)`` batches.
            epochs: Number of epochs. Defaults to the module-level ``EPOCHS``.
            criterion: Loss callable. Defaults to the module-level ``criterion``.
            optimizer: Optimizer over this model's parameters. Defaults to the
                module-level ``optimizer``.
            scheduler: Optional LR scheduler, stepped whenever validation
                accuracy drops below the best seen so far. Defaults to the
                module-level ``scheduler`` if one exists.

        Raises:
            NameError: If a required setting is neither passed nor defined at
                module level.
        """
        # Fall back to the globals the script defines, so the two-argument
        # call keeps working.
        if epochs is None:
            epochs = self._module_default("EPOCHS")
        if criterion is None:
            criterion = self._module_default("criterion")
        if optimizer is None:
            optimizer = self._module_default("optimizer")
        if scheduler is None:
            scheduler = self._module_default("scheduler", required=False)

        log_interval = 500
        total_accu = None  # best validation accuracy seen so far
        for epoch in range(1, epochs + 1):
            epoch_start_time = time.time()
            self.train()  # evaluate() below may have changed the mode
            total_acc, total_count = 0, 0
            for idx, (label, text, offsets) in enumerate(train_dataloader):
                optimizer.zero_grad()
                predicted_label = self(text, offsets)
                loss = criterion(predicted_label, label)
                loss.backward()
                # Clip this model's own gradients (not those of a global
                # `model`), after backward() and before the update.
                torch.nn.utils.clip_grad_norm_(self.parameters(), 0.1)
                optimizer.step()
                total_acc += (predicted_label.argmax(1) == label).sum().item()
                total_count += label.size(0)
                if idx % log_interval == 0 and idx > 0:
                    print(
                        "| epoch {:3d} | {:5d}/{:5d} batches "
                        "| accuracy {:8.3f}".format(
                            epoch, idx, len(train_dataloader), total_acc / total_count
                        )
                    )
                    total_acc, total_count = 0, 0
            accu_val = self.evaluate(valid_dataloader)
            if total_accu is not None and total_accu > accu_val:
                # Validation accuracy regressed: decay the learning rate.
                if scheduler is not None:
                    scheduler.step()
            else:
                total_accu = accu_val
            print("-" * 59)
            print(
                "| end of epoch {:3d} | time: {:5.2f}s | "
                "valid accuracy {:8.3f} ".format(
                    epoch, time.time() - epoch_start_time, accu_val
                )
            )
            print("-" * 59)

    def save_model(self, file_path):
        """Save weights, constructor arguments and the vocab to *file_path*."""
        model_state = {
            'state_dict': self.state_dict(),
            'vocab_size': self.vocab_size,
            'embed_dim': self.emsize,
            'num_class': self.num_class,
            'vocab': self.vocab,
        }
        torch.save(model_state, file_path)
        print("Model saved successfully.")

    def load_model(self, file_path):
        """Build and return a new model from a checkpoint written by ``save_model``.

        The receiver itself is not modified; the restored model is returned.
        """
        # NOTE(security): torch.load unpickles arbitrary objects (the stored
        # vocab among them); only load checkpoints from trusted sources.
        model_state = torch.load(file_path, map_location=torch.device('cpu'))
        model = TextClassificationModel(
            model_state['vocab_size'],
            model_state['embed_dim'],
            model_state['num_class'],
            model_state['vocab'],
        )
        # Without this the returned model would carry fresh random weights.
        model.load_state_dict(model_state['state_dict'])
        print("Model loaded successfully.")
        return model

    def evaluate(self, dataloader):
        """Return the accuracy over *dataloader* (0.0 if it yields no samples)."""
        was_training = self.training
        self.eval()
        total_acc, total_count = 0, 0
        try:
            with torch.no_grad():
                for label, text, offsets in dataloader:
                    predicted_label = self(text, offsets)
                    total_acc += (predicted_label.argmax(1) == label).sum().item()
                    total_count += label.size(0)
        finally:
            # Leave the module in whatever mode the caller had it in.
            self.train(was_training)
        if total_count == 0:
            return 0.0
        return total_acc / total_count

    def predict(self, text):
        """Return the logits for a single raw *text* string."""
        device = self.embedding.weight.device
        with torch.no_grad():
            # An explicit int64 dtype keeps an empty token list from becoming
            # a float tensor.
            tokens = torch.tensor(
                self.text_pipeline(text), dtype=torch.int64, device=device
            )
            output = self(tokens, torch.tensor([0], dtype=torch.int64, device=device))
        return output

    @staticmethod
    def read_gz_json(file_path):
        """Yield ``(text, category)`` for each object in a gzipped JSON array."""
        with gzip.open(file_path, 'rt', encoding='utf-8') as f:
            data = json.load(f)
        for obj in data:
            yield obj['text'], obj['category']

    @staticmethod
    def preprocess_text(text):
        """Split *text* into sentences with NLTK's ``sent_tokenize``."""
        return sent_tokenize(text)

    @staticmethod
    def data_iter(file_paths, categories):
        """Yield ``(label_index, sentence)`` for every sentence in *file_paths*.

        ``label_index`` is the position of the record's category in
        *categories* (the first position if a name is repeated).

        Raises:
            ValueError: If a record's category is not in *categories*.
        """
        category_index = {}
        for position, name in enumerate(categories):
            category_index.setdefault(name, position)
        for path in file_paths:
            for text, category in TextClassificationModel.read_gz_json(path):
                if category not in category_index:
                    raise ValueError(
                        "Unknown category {!r} in {}".format(category, path)
                    )
                label = category_index[category]
                for sentence in TextClassificationModel.preprocess_text(text):
                    yield label, sentence

    @staticmethod
    def collate_batch(batch):
        """Collate ``(label, text)`` samples into ``(labels, tokens, offsets)`` tensors.

        Relies on the module-level ``text_pipeline``, ``label_pipeline`` and
        ``device`` defined by the script section.
        """
        label_list, text_list, offsets = [], [], [0]
        for _label, _text in batch:
            label_list.append(label_pipeline(_label))
            processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
            text_list.append(processed_text)
            offsets.append(processed_text.size(0))
        label_list = torch.tensor(label_list, dtype=torch.int64)
        # The cumulative sum of the sample lengths (minus the last) gives each
        # sample's start position in the concatenated token tensor.
        offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
        text_list = torch.cat(text_list)
        return label_list.to(device), text_list.to(device), offsets.to(device)
def parse_arguments(argv=None):
    """Parse the command-line options for training and evaluation.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``train_path`` and ``test_path`` (lists
        of strings), ``epochs`` (int), ``lr`` (float) and ``batch_size`` (int).
    """
    parser = argparse.ArgumentParser(description="Text Classification Model")
    parser.add_argument("--train_path", type=str, nargs='+', required=True, help="Path to the training data")
    parser.add_argument("--test_path", type=str, nargs='+', required=True, help="Path to the test data")
    parser.add_argument("--epochs", type=int, default=5, help="Number of epochs for training")
    # argparse does not run non-string defaults through `type`, so give the
    # default as a float to keep args.lr a float either way.
    parser.add_argument("--lr", type=float, default=3.0, help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size for training")
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: build the vocabulary and datasets, train the
    # classifier, then report accuracy on the test set.
    #
    # The names defined here (EPOCHS, criterion, optimizer, scheduler, device,
    # text_pipeline, label_pipeline) are module globals on purpose: the
    # TextClassificationModel methods read them.
    args = parse_arguments()
    categories = ['Geography', 'Religion', 'Philosophy', 'Trash', 'Mythology', 'Literature', 'Science', 'Social Science', 'History', 'Current Events', 'Fine Arts']
    test_path = args.test_path
    train_path = args.train_path
    EPOCHS = args.epochs
    LR = args.lr
    BATCH_SIZE = args.batch_size

    tokenizer = get_tokenizer("basic_english")

    # Read each split once and reuse it for both vocabulary building and the
    # dataset, instead of re-parsing the gzip files for every pass.
    train_samples = list(TextClassificationModel.data_iter(train_path, categories))
    test_samples = list(TextClassificationModel.data_iter(test_path, categories))

    def yield_tokens(data_iter):
        """Yield the token list of every (label, text) sample."""
        for _, text in data_iter:
            yield tokenizer(text)

    vocab = build_vocab_from_iterator(yield_tokens(train_samples), specials=["<unk>"])
    # Map out-of-vocabulary tokens (e.g. words that only occur in the test
    # set) to <unk> instead of raising on lookup.
    vocab.set_default_index(vocab["<unk>"])
    text_pipeline = lambda x: vocab(tokenizer(x))
    label_pipeline = lambda x: int(x)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Labels are positions in `categories`, so the output layer has to cover
    # the whole list even if some category is absent from the training files.
    num_class = len(categories)
    vocab_size = len(vocab)
    emsize = 64
    model = TextClassificationModel(vocab_size, emsize, num_class, vocab).to(device)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
    total_accu = None

    train_dataset = to_map_style_dataset(train_samples)
    test_dataset = to_map_style_dataset(test_samples)
    # Hold out 5% of the training samples for validation.
    num_train = int(len(train_dataset) * 0.95)
    split_train_, split_valid_ = random_split(
        train_dataset, [num_train, len(train_dataset) - num_train]
    )
    train_dataloader = DataLoader(
        split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=TextClassificationModel.collate_batch
    )
    valid_dataloader = DataLoader(
        split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=TextClassificationModel.collate_batch
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=TextClassificationModel.collate_batch
    )

    # Train before measuring; otherwise the test accuracy below would describe
    # a randomly initialised model.
    model.train_model(train_dataloader, valid_dataloader)

    print("Checking the results of test dataset.")
    accu_test = model.evaluate(test_dataloader)
    print("test accuracy {:8.3f}".format(accu_test))