import torch
import torch.nn as nn
import torch.optim as optim
from itertools import islice
# from torchtext import data  # only needed by the commented-out StackOverflow dataset below
from gensim.corpora import WikiCorpus
from transformers import GPT2Tokenizer, GPT2Model
from functions import *
# Define the model
# class GPT(nn.Module):
#     def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
#         super().__init__()
#         self.embedding = nn.Embedding(vocab_size, embedding_dim)
#         self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
#         self.fc = nn.Linear(hidden_dim, vocab_size)
#         self.gpt2 = model
#
#     def forward(self, x):
#         # Embed the input
#         x = self.embedding(x)
#         # Pass through the GPT2 model
#         x = self.gpt2(x)
#         # Pass through the LSTM
#         x, _ = self.lstm(x)
#         # Pass through the fully connected layer
#         x = self.fc(x)
#         return x
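# Note: if this class were revived, GPT2Model would need the embeddings passed
# as inputs_embeds (e.g. self.gpt2(inputs_embeds=x).last_hidden_state) rather
# than as positional input_ids; the class is left unused below.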
# Load the GPT2 model
print('load gpt2 model')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')
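# GPT-2 ships without a padding token; reusing EOS is the usual workaround so
# the batched tokenization below can pad (not part of the original script)
tokenizer.pad_token = tokenizer.eos_token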
# Load the data
print('load custom data')
# wiki_corpus_en = WikiCorpus('data/enwiki-latest-pages-articles.xml.bz2')
wiki_corpus_fr = WikiCorpus('data/frwiki-latest-pages-articles.xml.bz2')
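# Note: constructing a WikiCorpus without passing a dictionary makes gensim
# scan the whole dump to build one first, which can take hours; passing
# dictionary={} is a commonly suggested shortcut when only get_texts() is used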
# stackoverflow_corpus = data.TabularDataset('data/stackoverflow.csv', format='csv', fields=['text'])
# Preprocess the data
print('Preprocess the data')
# Iterating a WikiCorpus directly yields bag-of-words vectors; get_texts()
# yields the raw token lists we want here. Only the first 10 articles are
# kept so the batch lines up with the 10 toy labels below (an assumption).
# wiki_data_en = [' '.join(tokens) for tokens in wiki_corpus_en.get_texts()]
wiki_data_fr = [' '.join(tokens) for tokens in islice(wiki_corpus_fr.get_texts(), 10)]
# stackoverflow_data = [example.text for example in stackoverflow_corpus]
# Convert the data to a format compatible with PyTorch
print('Convert the data to a format compatible with PyTorch')
# torch.tensor() cannot consume raw strings, so encode with the GPT-2
# tokenizer instead (max_length=128 is an arbitrary choice for this sketch)
# wiki_data_en = tokenizer(wiki_data_en, padding=True, truncation=True, max_length=128, return_tensors='pt')['input_ids']
encodings = tokenizer(wiki_data_fr, padding=True, truncation=True, max_length=128, return_tensors='pt')
wiki_data_fr = encodings['input_ids']
wiki_attention_mask = encodings['attention_mask']
# stackoverflow_data = tokenizer(stackoverflow_data, padding=True, truncation=True, max_length=128, return_tensors='pt')['input_ids']
# Define the Adam optimizer
print('Define the Adam optimizer')
# The labels below are binary class ids, so add a small classification head
# on top of GPT-2's hidden states (an assumption; GPT2Model itself has no head)
classifier = nn.Linear(model.config.n_embd, 2)
# lr=0.001 is high for fine-tuning GPT-2; something like 5e-5 is more typical
optimizer = optim.Adam(list(model.parameters()) + list(classifier.parameters()), lr=0.001)
# Define the loss function
print('Define the loss function')
criterion = nn.CrossEntropyLoss()
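# nn.CrossEntropyLoss expects raw logits of shape [batch, num_classes] plus
# integer class ids, which is why the classifier head above emits 2 logits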
# Train the model
print('Train the model')
num_epochs = 10
# Toy binary labels, one per article (10 articles were kept above)
labels = torch.tensor([0, 1, 1, 0, 0, 1, 0, 1, 0, 1])
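# A minimal mini-batching sketch (an alternative to the full-batch training
# below; kept commented out, names as defined above):
# dataset = torch.utils.data.TensorDataset(wiki_data_fr, wiki_attention_mask, labels)
# loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)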
for epoch in range(num_epochs):
    print('epoch: ' + str(epoch))
    # Forward pass: GPT2Model returns hidden states, not logits
    # outputs = model(wiki_data, stackoverflow_data)
    hidden = model(wiki_data_fr, attention_mask=wiki_attention_mask).last_hidden_state
    # Mean-pool over the sequence dimension and classify
    outputs = classifier(hidden.mean(dim=1))
    # Calculate the loss
    loss = criterion(outputs, labels)
    # Backward pass
    loss.backward()
    # Update the parameters
    optimizer.step()
    # Reset the gradients
    optimizer.zero_grad()
    # Evaluate the model (evaluate is expected to come from functions.py)
    accuracy = evaluate(model, wiki_data_fr)
    # Save the model weights and states
    torch.save(model.state_dict(), 'model.pth')
    # Adjust the learning rate (adjust_learning_rate comes from functions.py)
    adjust_learning_rate(optimizer, epoch)
    # Print the loss and accuracy
    print('Epoch: {}, Loss: {:.4f}, Accuracy: {:.4f}'.format(epoch + 1, loss.item(), accuracy))
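# Quick sanity check after training: predict the toy labels back (a sketch,
# not part of the original script)
with torch.no_grad():
    hidden = model(wiki_data_fr, attention_mask=wiki_attention_mask).last_hidden_state
    preds = classifier(hidden.mean(dim=1)).argmax(dim=-1)
    print('predictions:', preds.tolist())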