from transformers import GPT2LMHeadModel, AutoTokenizer
from transformers import AdamW, get_scheduler, set_seed
from datasets import load_dataset
from accelerate import Accelerator
import datasets, transformers
from huggingface_hub import Repository

from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from argparse import Namespace
import torch
import logging
import wandb
import time
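
# ConstantLengthDataset packs the streamed source files into fixed-size
# training sequences: raw texts are buffered until roughly num_of_sequences
# sequences' worth of characters accumulate (estimated with the
# chars_per_token heuristic), tokenized in one batch, joined with BOS tokens
# as separators, and sliced into seq_length-sized chunks.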
class ConstantLengthDataset(IterableDataset):
    # Note: `logger` and `accelerator` used below are module-level globals
    # defined further down in the script.
    def __init__(self, tokenizer, dataset, seq_length=1024,
                 num_of_sequences=1024, chars_per_token=3.6):
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.bos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.input_characters = seq_length * chars_per_token * num_of_sequences
        self.produced_samples = 0

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer = []
            buffer_len = 0
            logger.debug(f'index: {accelerator.process_index}, filling up buffer, getting next element ({self.produced_samples})')
            # Fill the character buffer from the streamed dataset
            while True:
                if buffer_len >= self.input_characters:
                    break
                try:
                    buffer.append(next(iterator)['content'])
                    buffer_len += len(buffer[-1])
                    self.produced_samples += 1
                except StopIteration:
                    more_examples = False
                    break
            # Tokenize the whole buffer in a single batch call
            tokenized_inputs = self.tokenizer(buffer, truncation=False)['input_ids']
            logger.debug(f'index: {accelerator.process_index}, buffer tokenized')
            # Concatenate all sequences, separated by the BOS token
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            # Slice into constant-length chunks; drop the final partial chunk
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    yield torch.tensor(input_ids)
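
# setup_logging configures a per-process logger (console plus a
# log/debug_<rank>.log file); only the main process reports to Weights &
# Biases and TensorBoard, the other ranks are limited to error-level output.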
def setup_logging(project_name):
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, handlers=[
            logging.FileHandler(f"log/debug_{accelerator.process_index}.log"),
            logging.StreamHandler()])
    if accelerator.is_main_process:  # We only want to set up logging once
        wandb.init(project=project_name, config=args)
        run_name = wandb.run.name
        tb_writer = SummaryWriter()
        tb_writer.add_hparams(vars(args), {'0': 0})
        logger.setLevel(logging.INFO)
        datasets.utils.logging.set_verbosity_debug()
        transformers.utils.logging.set_verbosity_info()
    else:
        tb_writer = None
        run_name = ''
        logger.setLevel(logging.ERROR)
        # datasets stays at debug verbosity so streaming activity from every
        # rank still shows up in the per-process debug log files
        datasets.utils.logging.set_verbosity_debug()
        transformers.utils.logging.set_verbosity_error()
    return logger, tb_writer, run_name
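
# create_dataloaders streams the train and validation splits and wraps them
# in ConstantLengthDataset so every batch contains only full-length sequences.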
def create_dataloaders(dataset_name):
    train_data = load_dataset(dataset_name+'-train', split="train",
                              streaming=True, chunksize=40<<20,
                              error_bad_chunk=False)
    train_data = train_data.shuffle(buffer_size=args.shuffle_buffer,
                                    seed=args.seed)
    # The validation dataset also exposes its single split as "train"
    valid_data = load_dataset(dataset_name+'-valid', split="train",
                              streaming=True, chunksize=40<<20,
                              error_bad_chunk=False)
    train_dataset = ConstantLengthDataset(tokenizer, train_data,
                                          seq_length=args.seq_length)
    valid_dataset = ConstantLengthDataset(tokenizer, valid_data,
                                          seq_length=args.seq_length)
    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
    eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
    return train_dataloader, eval_dataloader
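
# get_grouped_params builds two optimizer parameter groups so weight decay is
# applied to all parameters except biases and LayerNorm weights.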
def get_grouped_params(model, no_decay=["bias", "LayerNorm.weight"]):
    params_with_wd, params_without_wd = [], []
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay):
            params_without_wd.append(p)
        else:
            params_with_wd.append(p)
    return [{'params': params_with_wd, 'weight_decay': args.weight_decay},
            {'params': params_without_wd, 'weight_decay': 0.0}]
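
# log_metrics logs on every process, but only the main process forwards the
# metrics to Weights & Biases and TensorBoard.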
def log_metrics(step, metrics):
    logger.info(f"Step {step}: {metrics}")
    if accelerator.is_main_process:
        wandb.log(metrics)
        for k, v in metrics.items():
            tb_writer.add_scalar(k, v, step)
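
# evaluate computes the mean validation loss (gathered across all processes)
# and reports perplexity as the exponential of that loss.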
def evaluate():
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(batch, labels=batch)
        # Repeat the per-process loss so gather weights every sample equally
        loss = outputs.loss.repeat(args.valid_batch_size)
        losses.append(accelerator.gather(loss))
        if args.max_eval_steps > 0 and step >= args.max_eval_steps:
            break
    loss = torch.mean(torch.cat(losses))
    try:
        perplexity = torch.exp(loss)
    except OverflowError:
        perplexity = float("inf")
    return loss.item(), perplexity.item()
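
# Hyperparameters. Batch sizes are per process; with
# gradient_accumulation_steps=1 each optimizer step sees
# train_batch_size * num_processes samples (samples_per_step below).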
project_name = 'transformersbook/codeparrot-small'
dataset_name = '../codeparrot'
config = {"train_batch_size": 12,
          "valid_batch_size": 12,
          "weight_decay": 0.1,
          "shuffle_buffer": 1000,
          "learning_rate": 5e-4,
          "lr_scheduler_type": "cosine",
          "num_warmup_steps": 2000,
          "gradient_accumulation_steps": 1,
          "max_train_steps": 150_000,
          "max_eval_steps": -1,
          "seq_length": 1024,
          "seed": 1,
          "save_checkpoint_steps": 15_000}
args = Namespace(**config)
set_seed(args.seed)
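
# Accelerator handles device placement and distributed training.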
accelerator = Accelerator()
samples_per_step = accelerator.state.num_processes * args.train_batch_size
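
# Set up logging; the wandb run name doubles as the repository branch name.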
logger, tb_writer, run_name = setup_logging(project_name.split("/")[1])
logger.info(accelerator.state)
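
# Load model and tokenizer from the working directory; only the main process
# clones the Hub repository, checked out at a branch named after the run.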
if accelerator.is_main_process:
    hf_repo = Repository("./", clone_from=project_name, revision=run_name)
model = GPT2LMHeadModel.from_pretrained("./")
tokenizer = AutoTokenizer.from_pretrained("./")
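
# Load dataset and dataloader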
train_dataloader, eval_dataloader = create_dataloaders(dataset_name)
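
# Prepare the optimizer and learning rate scheduler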
optimizer = AdamW(get_grouped_params(model), lr=args.learning_rate)
lr_scheduler = get_scheduler(name=args.lr_scheduler_type, optimizer=optimizer,
                             num_warmup_steps=args.num_warmup_steps,
                             num_training_steps=args.max_train_steps)

def get_lr():
    return optimizer.param_groups[0]['lr']
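
# Let accelerate wrap the model, optimizer, and dataloaders for
# (potentially distributed) training.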
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader)
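
# Train model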
model.train()
completed_steps = 0
for step, batch in enumerate(train_dataloader, start=1):
    logger.debug(f'{step}|{accelerator.process_index}|got batch')
    loss = model(batch, labels=batch).loss
    logger.debug(f'{step}|{accelerator.process_index}|forward pass done')
    log_metrics(step, {'lr': get_lr(), 'samples': step * samples_per_step,
                       'steps': completed_steps, 'loss/train': loss.item()})
    loss = loss / args.gradient_accumulation_steps
    accelerator.backward(loss)
    logger.debug(f'{step}|{accelerator.process_index}|backward pass done')
    if step % args.gradient_accumulation_steps == 0:
        optimizer.step()
        logger.debug(f'{step}|{accelerator.process_index}|optimization done')
        lr_scheduler.step()
        optimizer.zero_grad()
        completed_steps += 1
    if step % args.save_checkpoint_steps == 0:
        logger.info('Evaluating model checkpoint')
        eval_loss, perplexity = evaluate()
        log_metrics(step, {'loss/eval': eval_loss, 'perplexity': perplexity})
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        if accelerator.is_main_process:
            logger.info('Saving model checkpoint')
            unwrapped_model.save_pretrained("./")
            hf_repo.push_to_hub(commit_message=f'step {step}')
        model.train()
    if completed_steps >= args.max_train_steps:
        break
    logger.debug(f'{step}|{accelerator.process_index}|train loop done')
    # Never true as written (`step` starts at 1): presumably a manual debug
    # hook; change -1 to a real step number to enable debug logging from there
    if step == -1:
        logger.setLevel(logging.DEBUG)
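
# Evaluate and save the last checkpoint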
logger.info('Evaluating and saving model after training')
eval_loss, perplexity = evaluate()
log_metrics(step, {'loss/eval': eval_loss, 'perplexity': perplexity})
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
if accelerator.is_main_process:
    unwrapped_model.save_pretrained("./")
    hf_repo.push_to_hub(commit_message='final model')