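"""Fine-tune GPT-2 on a French Wikipedia dump with a causal LM objective.

Text is extracted with gensim's WikiCorpus, tokenized with the pretrained
GPT-2 tokenizer, and fed through GPT2LMHeadModel for next-token prediction.
"""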
import torch
import torch.nn as nn
import torch.optim as optim
# from torchtext import data  # only needed for the commented-out StackOverflow loader below
from gensim.corpora import WikiCorpus
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from functions import evaluate, adjust_learning_rate  # local helper module

# An earlier experiment kept for reference: wrap GPT-2 with an extra LSTM head.
# Note: embedding_dim must match GPT-2's hidden size (768 for 'gpt2') so the
# embeddings can be fed in via inputs_embeds.
# class GPT(nn.Module):
#     def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, gpt2):
#         super().__init__()
#         self.embedding = nn.Embedding(vocab_size, embedding_dim)
#         self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
#         self.fc = nn.Linear(hidden_dim, vocab_size)
#         self.gpt2 = gpt2  # pass the loaded GPT-2 model in rather than using a global
#
#     def forward(self, x):
#         # Embed the input token ids
#         x = self.embedding(x)
#         # Pass the embeddings through GPT-2
#         x = self.gpt2(inputs_embeds=x).last_hidden_state
#         # Pass through the LSTM
#         x, _ = self.lstm(x)
#         # Project the hidden states back onto the vocabulary
#         x = self.fc(x)
#         return x

# Load the pretrained GPT-2 tokenizer and model (the LM head is needed for the loss below)
print('load gpt2 model')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS for batching
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Load the data
print('load custom data')
# wiki_corpus_en = WikiCorpus('data/enwiki-latest-pages-articles.xml.bz2')
wiki_corpus_fr = WikiCorpus('data/frwiki-latest-pages-articles.xml.bz2')
# stackoverflow_corpus = data.TabularDataset('data/stackoverflow.csv', format='csv', fields=['text'])

# Preprocess the data: get_texts() yields one token list per article
# (iterating the corpus directly would yield bag-of-words vectors instead)
print('Preprocess the data')
# wiki_data_en = [' '.join(tokens) for tokens in wiki_corpus_en.get_texts()]
wiki_data_fr = [' '.join(tokens) for tokens in wiki_corpus_fr.get_texts()]
# stackoverflow_data = [example.text for example in stackoverflow_corpus]

# Convert the data to a format compatible with PyTorch: strings cannot go through
# torch.tensor() directly, so tokenize into fixed-length id tensors instead.
# Only a small slice is encoded here to keep the single-batch loop below tractable.
print('Convert the data to a format compatible with PyTorch')
encodings = tokenizer(wiki_data_fr[:32], truncation=True, max_length=128,
                      padding='max_length', return_tensors='pt')
input_ids = encodings['input_ids']
attention_mask = encodings['attention_mask']

# Define the Adam optimizer (lr=0.001 is aggressive for fine-tuning GPT-2; ~5e-5 is more typical)
print('Define the Adam optimizer')
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Define the loss function (its default ignore_index=-100 lets us mask padding in the labels)
print('Define the loss function')
criterion = nn.CrossEntropyLoss()

# Train the model
print('Train the model')
num_epochs = 10
# Causal-LM targets are the input ids themselves; padded positions are set to
# -100 so the loss ignores them (logits and labels are shifted by one below)
labels = input_ids.masked_fill(attention_mask == 0, -100)

model.train()
for epoch in range(num_epochs):
    print(f'epoch: {epoch}')
    # Forward pass (one fixed batch for illustration; real training would
    # iterate over mini-batches with a DataLoader)
    logits = model(input_ids, attention_mask=attention_mask).logits
    # Calculate the loss: predict token t+1 from tokens up to t
    shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
    shift_labels = labels[:, 1:].reshape(-1)
    loss = criterion(shift_logits, shift_labels)
    # Backward pass
    loss.backward()
    # Update the parameters
    optimizer.step()
    # Reset the gradients
    optimizer.zero_grad()
    # Evaluate the model (evaluate() comes from the local functions module)
    accuracy = evaluate(model, input_ids)
    # Save the model weights and states after each epoch
    torch.save(model.state_dict(), 'model.pth')
    # Adjust the learning rate (also from the local functions module)
    adjust_learning_rate(optimizer, epoch)
    # Print the loss and accuracy
    print('Epoch: {}, Loss: {:.4f}, Accuracy: {:.4f}'.format(epoch + 1, loss.item(), accuracy))
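
# The local functions module is not shown in this file. For reference, a minimal
# sketch of what its two helpers might look like (assumptions, not the actual code):
#
# def evaluate(model, input_ids):
#     """Next-token prediction accuracy on one batch (hypothetical helper)."""
#     model.eval()
#     with torch.no_grad():
#         preds = model(input_ids).logits[:, :-1, :].argmax(dim=-1)
#         accuracy = (preds == input_ids[:, 1:]).float().mean().item()
#     model.train()
#     return accuracy
#
# def adjust_learning_rate(optimizer, epoch, base_lr=0.001, decay=0.5, step=5):
#     """Halve the learning rate every `step` epochs (hypothetical schedule)."""
#     lr = base_lr * (decay ** ((epoch + 1) // step))
#     for group in optimizer.param_groups:
#         group['lr'] = lr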