# nice_model_test / Train.py
# Uploaded by MAKILINGDING ("Upload 5 files", commit a3b29ff).
import json
import math
import os
import time

import torch
import torch.nn.functional as F
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from transformers import AdamW, get_scheduler

import Tokenizer
### TOKENIZER ##########################################################################################################
# Both vocabularies come from the project-local Tokenizer module; len(vocabulary)
# sizes the Transformer's embedding table and its output projection below.
vocabulary, token_vocabulary = (
    Tokenizer.get_vocabulary(),
    Tokenizer.get_token_vocabulary(),
)
### TRANSFORMER ########################################################################################################
d_model = 384  # embedding / hidden width
num_heads = 6  # attention heads; d_model is split evenly across them
drop_prob = 0.1  # dropout probability passed to the nn.Dropout layers
batch_size = 38  # sequences per optimizer step
max_sequence_length = 256  # tokens per training example
ffn_hidden = d_model * 4  # hidden width of the position-wise feed-forward layer
num_layers = 6  # number of stacked TransformerLayer blocks
save_path = 'models/my_model.pt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# MultiHeadAttention uses head_dim = d_model // num_heads and reshapes the
# 3 * d_model projection into num_heads * 3 * head_dim; a remainder would make
# that reshape fail deep inside the first forward pass, so check it up front.
if d_model % num_heads != 0:
    raise ValueError(f'd_model ({d_model}) must be divisible by num_heads ({num_heads})')
def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k) + mask) @ v.

    Args:
        q: (..., seq_q, d_k) queries.
        k: (..., seq_k, d_k) keys.
        v: (..., seq_k, d_v) values. Leading dims of q, k, v must broadcast.
        mask: optional additive mask broadcastable to the (..., seq_q, seq_k)
            score tensor; 0 leaves a score unchanged, -inf removes it.

    Returns:
        (values, attention): the attended values (..., seq_q, d_v) and the
        softmax weights (..., seq_q, seq_k).
    """
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        # Follow the scores' device/dtype instead of a module-level `device`
        # global, so the function works wherever its inputs live.
        scores = scores + mask.to(device=scores.device, dtype=scores.dtype)
    attention = F.softmax(scores, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention.

    One linear layer produces the query/key/value projections for all heads at
    once; the per-head outputs are concatenated and mixed by a final linear
    layer. Attribute names are unchanged so saved state_dicts still load.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_layer = nn.Linear(d_model, 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Attend over x and return a tensor of the same shape.

        Args:
            x: (batch, seq_len, d_model) input (the three-way unpacking of
                x.size() requires rank 3).
            mask: optional additive attention mask forwarded unchanged to
                scaled_dot_product.
        """
        batch_size, seq_len, _ = x.size()
        qkv = self.qkv_layer(x)
        # (batch, seq, 3*d_model) -> (batch, heads, seq, 3*head_dim), then split
        # into q, k, v of shape (batch, heads, seq, head_dim) each.
        qkv = qkv.reshape(batch_size, seq_len, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)
        values, _ = scaled_dot_product(q, k, v, mask)
        # values is (batch, heads, seq, head_dim). Move the head axis back next
        # to head_dim before flattening: reshaping directly would splice several
        # sequence positions into one output row, which also leaks later
        # positions' outputs into earlier ones despite the causal mask.
        values = values.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.num_heads * self.head_dim)
        return self.linear_layer(values)
class LayerNormalization(nn.Module):
    """Layer normalization over the trailing ``len(parameters_shape)`` dims.

    Each slice is shifted to zero mean and divided by its (biased) standard
    deviation, then scaled by the learnable ``gamma`` and shifted by ``beta``.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape = parameters_shape
        self.eps = eps  # added to the variance before the square root
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, inputs):
        # Reduce over the last n dimensions: -1, -2, ..., -n.
        axes = list(range(-1, -len(self.parameters_shape) - 1, -1))
        centered = inputs - inputs.mean(dim=axes, keepdim=True)
        variance = centered.pow(2).mean(dim=axes, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        return self.gamma * normalized + self.beta
class PositionwiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every position.

    Shape flow: d_model -> hidden (GELU, dropout) -> d_model.
    """

    def __init__(self, d_model, hidden, drop_prob=0.1):
        super().__init__()
        # Attribute names double as state_dict keys; keep them stable.
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        activated = F.gelu(self.linear1(x))
        return self.linear2(self.dropout(activated))
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017).

    Column 2i holds sin(pos / 10000^(2i/d_model)) and column 2i+1 the matching
    cos. forward returns a (sequence_length, d_model) float tensor on torch's
    default device; callers move it where they need it.
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

    def forward(self, sequence_length):
        even_i = torch.arange(0, self.d_model, 2).float()
        denominator = torch.pow(10000, even_i / self.d_model)
        position = torch.arange(sequence_length).reshape(sequence_length, 1)
        even_pe = torch.sin(position / denominator)
        odd_pe = torch.cos(position / denominator)
        # Interleave sin/cos columns: (seq, ceil(d/2), 2) -> (seq, 2 * ceil(d/2)).
        interleaved = torch.stack([even_pe, odd_pe], dim=2).flatten(start_dim=1)
        # For an odd d_model the interleave produces one extra column; trim it so
        # the encoding always matches the embedding width.
        return interleaved[:, :self.d_model]
class TransformerLayer(nn.Module):
    """One post-norm block: causal self-attention, then a feed-forward MLP.

    Each sub-layer is wrapped as ``norm(sublayer(x) + x)``. dropout1/dropout2
    are constructed (so state_dict/config stay unchanged) but currently not
    applied in forward.
    """

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super().__init__()
        self.attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)

    @staticmethod
    def _build_mask(original_inputs, seq_len, device):
        """Return the additive (seq_len, seq_len) attention mask for a batch.

        Entries are 0 (attend) or -inf (blocked). Every query is blocked from
        later keys (causal). Additionally, with max_length = length of the
        leading run of non-zero ids in the batch row containing the most
        non-zero ids, queries at index >= max_length are blocked from keys at
        index >= max_length (assumes right-padding with token id 0).
        """
        non_pad = original_inputs != 0
        longest_row = non_pad[torch.argmax(non_pad.sum(dim=1))]
        # Leading run of non-zero ids, counted on-device in one op instead of a
        # per-element Python loop (which forced a host sync per token).
        max_length = int(longest_row.long().cumprod(dim=0).sum())
        # A fully -inf row makes softmax return NaN; always keep position 0
        # attendable so a batch whose longest row starts with id 0 stays finite.
        max_length = max(max_length, 1)
        mask = torch.full((seq_len, seq_len), float('-inf'), device=device).triu(1)
        mask[max_length:, max_length:] = float('-inf')
        return mask

    def forward(self, x, original_inputs):
        """Apply the block to x of shape (batch, seq_len, d_model).

        original_inputs are the token ids x was embedded from; they are used
        only to build the attention mask.
        """
        mask = self._build_mask(original_inputs, x.size(1), x.device)
        residual_x = x
        x = self.attention(x, mask=mask)
        # x = self.dropout1(x)  # currently disabled
        x = self.norm1(x + residual_x)
        residual_x = x
        x = self.ffn(x)
        # x = self.dropout2(x)  # currently disabled
        x = self.norm2(x + residual_x)
        return x
class SequentialTransformer(nn.Sequential):
    """nn.Sequential that threads an extra argument through every layer.

    Each layer is called as ``layer(x, original_inputs)`` and its output becomes
    the next layer's ``x``; ``original_inputs`` is passed unchanged to all of
    them. With no layers, ``x`` is returned as-is.
    """

    def forward(self, *inputs):
        x, original_inputs = inputs
        for module in self:
            # Chain the layers: each one consumes the previous layer's output
            # (previously every layer saw the raw input and all but the last
            # layer's results were discarded).
            x = module(x, original_inputs)
        return x
class Transformer(nn.Module):
    """Causal transformer over token ids.

    Token embedding (scaled by sqrt(d_model)) plus sinusoidal positions, a stack
    of TransformerLayer blocks, and a linear projection to len(vocabulary)
    logits per position.
    """

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, num_layers):
        super().__init__()
        self.d_model = d_model
        self.token_embedding = nn.Embedding(len(vocabulary), d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.layers = SequentialTransformer(*[TransformerLayer(d_model, ffn_hidden, num_heads, drop_prob)
                                              for _ in range(num_layers)])
        self.output_layers = nn.Linear(d_model, len(vocabulary))

    def _logits(self, x):
        """Map token ids x (batch, seq_len) to logits (batch, seq_len, vocab)."""
        token_embeddings = self.token_embedding(x) * math.sqrt(self.d_model)
        # The (seq_len, d_model) encoding broadcasts over the batch dimension; it
        # follows the embeddings' device rather than the module-level global.
        pos_encoding = self.positional_encoding(x.size(1)).to(token_embeddings.device)
        hidden = self.layers(token_embeddings + pos_encoding, x)
        return self.output_layers(hidden)

    def forward(self, x, targets=None):
        """Return (logits, loss).

        Args:
            x: token ids, (batch, seq_len).
            targets: optional token ids with the same shape as x; when omitted
                the loss is None.
        """
        logits = self._logits(x)
        if targets is None:
            return logits, None
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    def generate(self, x):
        """Return the softmax distribution over the vocabulary at every position."""
        return F.softmax(self._logits(x), dim=-1)
### DATA PREPROCESSING #################################################################################################
# Wall-clock timer covering preprocessing and training; printed at the very end.
print("Data Preprocessing...")
start_time = time.time()
def save_tokenized_data(name, tokenized_dataset):
    """Write tokenized_dataset as JSON to the file at path name, overwriting it.

    name is used as-is (no directory prefix is added).
    """
    with open(name, mode='w') as handle:
        json.dump(obj=tokenized_dataset, fp=handle)
def load_tokenized_data(name):
    """Parse and return the JSON file tokenized_datasets/<name>.

    The path is relative to the current working directory.
    """
    path = f'tokenized_datasets/{name}'
    with open(path, mode='r') as handle:
        return json.load(handle)
# One-off generation of the tokenized JSON from C4 (kept for reproducibility):
# raw_dataset = load_dataset('c4', 'realnewslike') #***********************************#
# raw_dataset = raw_dataset['train'].select(range(round(len(raw_dataset['train']) / 1000)))
# raw_dataset = [Tokenizer.tokenize_sequence(raw_dataset['text'][i]) for i in range(len(raw_dataset['text']))]
# save_tokenized_data('tokenized_datasets/c4_realnewslike.json', raw_dataset)
raw_dataset = load_tokenized_data('c4_realnewslike.json') #***********************************#

# Concatenate the per-document token lists into one flat token stream.
token_dataset = [token for document in raw_dataset for token in document]

# Causal language-modelling examples. Each input is a window of
# max_sequence_length tokens and its target is the SAME window shifted one token
# to the right, so position j (which attends only to positions <= j) learns to
# predict the token that immediately follows it. Windows do not overlap and the
# trailing partial window is dropped.
num_examples = max(0, (len(token_dataset) - 1) // max_sequence_length)
train_input = []
train_output = []
for start in range(0, num_examples * max_sequence_length, max_sequence_length):
    train_input.append(torch.tensor(token_dataset[start:start + max_sequence_length], dtype=torch.long))
    train_output.append(torch.tensor(token_dataset[start + 1:start + max_sequence_length + 1], dtype=torch.long))
print(f'len(train_input) = {len(train_input)}')

train_dataset = list(zip(train_input, train_output))

# Group examples into batches: train_batch[0][b] / train_batch[1][b] are the lists
# of input / target tensors of batch b. When the example count is not a multiple
# of batch_size the last batch is simply smaller.
train_batch = [[], []]
for start in range(0, len(train_dataset), batch_size):
    chunk = train_dataset[start:start + batch_size]
    train_batch[0].append([inp for inp, _ in chunk])
    train_batch[1].append([out for _, out in chunk])
train_batch_count = len(train_batch[0])
### TRAINING ###########################################################################################################
print('Training...')
model = Transformer(d_model, ffn_hidden, num_heads, drop_prob, num_layers)
print(f'model parameters: {sum(p.numel() for p in model.parameters())}')
model.to(device)
epochs = 5
# NOTE(review): lr=0.01 is high for AdamW on a transformer; left unchanged here.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
# lr_scheduler.step() runs once per batch, so the schedule length must be counted
# in batches. Counting examples meant only ~1/batch_size of the linear decay was
# ever applied.
num_batches = len(train_batch[0])
num_training_steps = epochs * num_batches
lr_scheduler = get_scheduler(
    name='linear',
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
train_epoch_average_loss = []
for epoch in range(epochs):
    model.train()
    train_loss_total = 0.0
    for i, (batch_inputs, batch_labels) in enumerate(zip(train_batch[0], train_batch[1])):
        inputs = torch.stack(batch_inputs).to(device)
        labels = torch.stack(batch_labels).to(device)
        # Call the module (not .forward) so nn.Module hooks run.
        logits, loss = model(inputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        # .item() yields a plain float; summing the loss tensor itself would chain
        # every batch's autograd history onto the running total.
        train_loss_total += loss.item()
        if i % 10 == 0:
            print('TRAINING...')
            print(f'EPOCH {epoch}, batch {i}/{num_batches}')
            print(f'loss: {loss.item()}')
    train_epoch_average_loss.append(train_loss_total / max(num_batches, 1))
    # TODO: add an evaluation pass once a held-out split exists.
for i, average_loss in enumerate(train_epoch_average_loss):
    print(f'EPOCH {i} AVERAGE LOSS: {average_loss}')
# torch.save does not create missing directories; ensure the target folder exists
# so a finished run is not lost at the very last step.
os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
torch.save(model.state_dict(), save_path)
end_time = time.time()
total_time = end_time - start_time
print(f'{total_time} seconds')