image-caption / train.py
Lwasinam's picture
Upload 5 files
22ca2be verified
raw
history blame contribute delete
No virus
14.9 kB
from model import build_transformer
from dataset import BilingualDataset, causal_mask
from config import get_config, get_weights_file_path
import datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader, random_split
from torch.optim.lr_scheduler import LambdaLR
import warnings
from tqdm import tqdm
import os
from pathlib import Path
# Huggingface datasets and tokenizers
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
import torchmetrics
import wandb
import accelerate
from torch.utils.tensorboard import SummaryWriter
from safetensors.torch import load_model, save_model
from accelerate import Accelerator
from transformers import GPT2TokenizerFast
import threading
def greedy_decode(model, source, source_mask, tokenizer_tgt, max_len, device):
    """Greedily decode one output sequence for a single source example.

    The source is encoded once. The decoder is then run repeatedly on the tokens
    produced so far, and the highest-scoring next token is appended each step.
    Decoding stops when ``[EOS]`` is produced or the sequence holds ``max_len`` tokens.

    Args:
        model: Either the bare model or a wrapper exposing it as ``.module``
            (as accelerate/DDP does in multi-process runs). Must provide
            ``encode``, ``decode`` and ``project``.
        source: Encoder input; presumably a single example (batch size 1),
            since only ``[SOS]`` for one sequence is created — TODO confirm.
        source_mask: Forwarded to ``decode``; the encoder itself is called
            with ``None``.
        tokenizer_tgt: Tokenizer providing ``convert_tokens_to_ids``.
        max_len: Maximum number of tokens in the result, including ``[SOS]``.
        device: Device on which the decoder tensors are created.

    Returns:
        1-D LongTensor of token ids that starts with ``[SOS]`` and, if it was
        produced before the length limit, ends with ``[EOS]``.
    """
    # accelerate only wraps the model (adding `.module`) when running
    # multi-process; support the unwrapped single-process case too.
    core = getattr(model, "module", model)
    sos_idx = tokenizer_tgt.convert_tokens_to_ids('[SOS]')
    eos_idx = tokenizer_tgt.convert_tokens_to_ids('[EOS]')

    # The encoder output does not depend on the decoded prefix; compute it once.
    encoder_output = core.encode(source, None)
    decoder_input = torch.full((1, 1), sos_idx, dtype=torch.long, device=device)

    # Using `<` (rather than `== max_len`) also terminates for max_len < 1.
    while decoder_input.size(1) < max_len:
        decoder_mask = causal_mask(decoder_input.size(1)).long().to(device)
        out = core.decode(encoder_output, source_mask, decoder_input, decoder_mask)
        # Only the last position predicts the next token. argmax over the raw
        # logits selects the same token as argmax over their softmax.
        logits = core.project(out[:, -1])
        next_word = torch.argmax(logits, dim=1)
        decoder_input = torch.cat([decoder_input, next_word.unsqueeze(0)], dim=1)
        if next_word.item() == eos_idx:
            break
    return decoder_input.squeeze(0)
def get_console_width(default=80):
    """Return the terminal width in columns, or ``default`` if it is unknown.

    Uses ``shutil.get_terminal_size``. It honours the ``COLUMNS`` environment
    variable, then queries the terminal attached to stdout, and falls back to
    ``default`` when neither is available (e.g. redirected output).
    """
    import shutil

    return shutil.get_terminal_size(fallback=(default, 24)).columns


def run_validation(model, validation_ds, tokenizer_tgt, max_len, device, print_msg, global_step, num_examples=3):
    """Greedy-decode a few validation examples and print them next to their targets.

    Args:
        model: Model passed to ``greedy_decode``. It is switched to eval mode
            and left there; the caller must switch it back to training mode.
        validation_ds: Iterable of batch dicts with ``"encoder_input"`` and
            ``"tgt_text"`` entries; every batch must hold exactly one example.
        tokenizer_tgt: Tokenizer used to turn predicted ids back into text.
        max_len: Maximum length of each generated sequence.
        device: Device the encoder input is moved to.
        print_msg: Callable that emits one line of output (e.g. a tqdm ``write``).
        global_step: Current training step; unused here, kept for interface
            compatibility.
        num_examples: Stop after this many examples have been printed.

    Raises:
        ValueError: If a batch contains more than one example.
    """
    model.eval()
    console_width = get_console_width()
    expected = []
    predicted = []

    with torch.no_grad():
        for count, batch in enumerate(validation_ds, start=1):
            encoder_input = batch["encoder_input"].to(device)  # (b, seq_len)
            if encoder_input.size(0) != 1:
                raise ValueError("Batch size must be 1 for validation")

            # Decoding is done without an encoder mask (source_mask=None).
            model_out = greedy_decode(model, encoder_input, None, tokenizer_tgt, max_len, device)
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())
            expected.append(target_text)
            predicted.append(model_out_text)

            print_msg('-' * console_width)
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg('-' * console_width)
                break
    # TODO: `expected` / `predicted` are collected so text metrics
    # (e.g. torchmetrics CharErrorRate / WordErrorRate / BLEUScore) can be logged.
def get_all_sentences(ds):
    """Lazily yield the ``'text'`` field of each record of ``ds`` in order."""
    yield from (record['text'] for record in ds)
def batch_iterator(data):
    """Lazily yield ``data[i]['text']`` for every position ``i`` in ``range(len(data))``.

    Records are fetched by integer index rather than by iterating ``data``.
    """
    yield from (data[idx]['text'] for idx in range(len(data)))
# Wraps batch_iterator (which yields one text per record) in a progress bar.
def tqdm_batch_iterator(data, *args, **kwargs):
    """Yield the items of ``batch_iterator(data, *args, **kwargs)`` behind a tqdm bar.

    The bar's total is ``len(data)``. Extra arguments are forwarded to
    ``batch_iterator`` unchanged.
    """
    progress = tqdm(batch_iterator(data, *args, **kwargs), total=len(data))
    for text in progress:
        yield text
def get_or_build_tokenizer(config, ds):
    """Return the target-side tokenizer.

    Loads the pretrained ``openai-community/gpt2`` fast tokenizer. ``[UNK]``,
    ``[SOS]``, ``[EOS]`` and ``[PAD]`` are set as its unknown, BOS, EOS and
    padding tokens.

    ``config`` and ``ds`` are currently unused. They remain from an earlier
    version that trained a word-level tokenizer on ``ds`` and cached it at
    ``config['tokenizer_file']``.
    """
    special_tokens = {
        'unk_token': '[UNK]',
        'bos_token': '[SOS]',
        'eos_token': '[EOS]',
        'pad_token': '[PAD]',
    }
    return GPT2TokenizerFast.from_pretrained("openai-community/gpt2", **special_tokens)
def get_ds(config):
    """Build the training/validation dataloaders and the target tokenizer.

    Data comes from the ``coco`` subset of ``MMInstruction/M3IT``. Training
    uses the full ``train`` split; validation uses the first 2% of the
    ``validation`` split. The validation loader has batch size 1 and is
    shuffled.

    Args:
        config: Reads ``seq_len`` and ``batch_size``.

    Returns:
        ``(train_dataloader, val_dataloader, tokenizer_tgt)``.
    """
    dataset_name, subset = "MMInstruction/M3IT", 'coco'
    train_raw = load_dataset(dataset_name, subset, split='train')
    val_raw = load_dataset(dataset_name, subset, split='validation[:2%]')

    tokenizer_tgt = get_or_build_tokenizer(config, train_raw)

    # Seed torch's global RNG (e.g. so DataLoader shuffling is repeatable).
    torch.manual_seed(20)

    train_set = BilingualDataset(train_raw, tokenizer_tgt, config['seq_len'])
    val_set = BilingualDataset(val_raw, tokenizer_tgt, config['seq_len'])

    train_loader = DataLoader(train_set, batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=True)
    return train_loader, val_loader, tokenizer_tgt
def get_model(config, vocab_tgt_len):
    """Construct the captioning transformer via ``build_transformer``.

    Args:
        config: Reads ``seq_len`` and ``d_model``.
        vocab_tgt_len: Size of the target vocabulary.
    """
    seq_len = config['seq_len']
    return build_transformer(vocab_tgt_len, seq_len, d_model=config['d_model'])
def _validation_loss(core, val_dataloader, loss_fn, vocab_size, accelerator, device):
    """Return the mean teacher-forced loss over ``val_dataloader``, or NaN if it is empty.

    Predictions and labels are gathered across processes with
    ``accelerator.gather_for_metrics`` before the loss is computed.
    """
    total, n_batches = 0.0, 0
    with torch.no_grad():
        progress = tqdm(val_dataloader, desc="Processing loss")
        for batch in progress:
            encoder_input = batch['encoder_input'].to(device)  # (b, seq_len)
            decoder_input = batch['decoder_input'].to(device)  # (B, seq_len)
            decoder_mask = batch['decoder_mask'].to(device)  # (B, 1, seq_len, seq_len)
            encoder_output = core.encode(encoder_input, None)
            decoder_output = core.decode(encoder_output, None, decoder_input, decoder_mask)
            proj_output = core.project(decoder_output)
            proj_output, label = accelerator.gather_for_metrics((proj_output, batch["label"]))
            loss = loss_fn(proj_output.view(-1, vocab_size), label.view(-1))
            progress.set_postfix({"loss": f"{loss.item():6.3f}"})
            # Accumulate a Python float rather than a device tensor.
            total += loss.item()
            n_batches += 1
    return total / n_batches if n_batches else float("nan")


def train_model(config):
    """Train the captioning transformer with accelerate and log to Weights & Biases.

    Each epoch does four things:
      * one pass over the training data;
      * a checkpoint via ``accelerator.save_state``;
      * a teacher-forced validation loss;
      * a few greedy-decoded samples.
    Greedy samples are also printed at fixed global steps during training.

    Config keys read here: ``project_name``, ``num_epochs``, ``batch_size``,
    ``lr``, ``model_folder``, ``seq_len`` and ``preload``. ``get_ds`` and
    ``get_model`` read further keys.

    The W&B API key is taken from the ``WANDB_API_KEY`` environment variable
    instead of being embedded in source.
    """
    accelerator = Accelerator()

    # Never commit credentials; if the variable is unset, wandb.login falls
    # back to its own credential lookup.
    wandb.login(key=os.environ.get("WANDB_API_KEY"))
    wandb.init(project="Vision", name=config['project_name'])
    wandb.config.epochs = config['num_epochs']
    wandb.config.batch_size = config['batch_size']
    wandb.config.learning_rate = config['lr']

    device = accelerator.device
    print("Using device:", device)

    # Make sure the weights folder exists.
    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_tgt = get_ds(config)
    vocab_size = len(tokenizer_tgt)
    model = get_model(config, vocab_size).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], betas=(0.9, 0.98), eps=1e-9)
    model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, val_dataloader
    )
    # accelerate only wraps the model (adding `.module`) in multi-process runs.
    # NOTE(review): calling the inner module's methods directly bypasses the
    # wrapper's forward(); confirm gradient synchronisation is as intended.
    core = getattr(model, "module", model)

    initial_epoch = 0
    global_step = 0
    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(f'Preloading model {model_filename}')
        accelerator.load_state(model_filename)
        # NOTE(review): the resume epoch is hard-coded and global_step is not
        # restored; confirm this matches the checkpoint being loaded.
        initial_epoch = 4

    pad_id = tokenizer_tgt.convert_tokens_to_ids('[PAD]')
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id, label_smoothing=0.1).to(device)

    # Global steps at which greedy-decoded samples are printed mid-epoch.
    sample_steps = {1000, 5000, 10000, 15000, 20000, 30000}

    for epoch in range(initial_epoch, config['num_epochs']):
        model.train()
        progress = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
        for batch in progress:
            encoder_input = batch["encoder_input"].to(device)  # (b, seq_len)
            decoder_input = batch["decoder_input"].to(device)  # (B, seq_len)
            decoder_mask = batch["decoder_mask"].to(device)  # (B, 1, seq_len, seq_len)
            label = batch["label"].to(device)  # (B, seq_len)

            # The encoder is called without a mask; the decoder gets the batch's decoder_mask.
            encoder_output = core.encode(encoder_input, None)
            decoder_output = core.decode(encoder_output, None, decoder_input, decoder_mask)
            proj_output = core.project(decoder_output)

            loss = loss_fn(proj_output.view(-1, vocab_size), label.view(-1))
            progress.set_postfix({"loss": f"{loss.item():6.3f}"})
            wandb.log({"Training Loss": loss.item(), "Global Step": global_step})

            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1

            if global_step in sample_steps:
                run_validation(model, val_dataloader, tokenizer_tgt, config['seq_len'], device, progress.write, global_step)
                model.train()

        # NOTE(review): checkpoints go to a hard-coded Kaggle path, while
        # preload reads get_weights_file_path(config, ...); confirm they agree.
        accelerator.save_state(output_dir=f'/kaggle/working/weights/tmodel_{epoch:02d}')

        model.eval()
        avg_val_loss = _validation_loss(core, val_dataloader, loss_fn, vocab_size, accelerator, device)
        accelerator.print(f"Epoch {epoch}, Validation Loss: {avg_val_loss}")
        wandb.log({"Validation Loss": avg_val_loss, "Global Step": global_step})
        run_validation(model, val_dataloader, tokenizer_tgt, config['seq_len'], device, progress.write, global_step)
if __name__ == '__main__':
    # Silence all Python warnings for the training run, then train with the
    # project configuration.
    warnings.filterwarnings("ignore")
    train_model(get_config())