from transformers import GPT2LMHeadModel, AutoTokenizer
from transformers import AdamW, get_scheduler, set_seed
#from transformers.training_args import OptimizerNames
from datasets import load_dataset
from accelerate import Accelerator
import datasets, transformers
from huggingface_hub import Repository
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from argparse import Namespace
import torch
import logging
import wandb
import sys
from torchinfo import summary


class ConstantLengthDataset(IterableDataset):
    """Streams raw text examples, tokenizes them, and packs the token ids into
    fixed-length blocks, separating documents with the BOS token."""

    def __init__(self, tokenizer, dataset, seq_length=1024,
                 num_of_sequences=1024, chars_per_token=3.6):
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.bos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        # Rough number of characters to buffer before tokenizing one batch of sequences.
        self.input_characters = seq_length * chars_per_token * num_of_sequences

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            # Fill the character buffer until it should hold ~num_of_sequences blocks.
            while True:
                if buffer_len >= self.input_characters:
                    break
                try:
                    #buffer.append(next(iterator)['content'])
                    buffer.append(next(iterator)['text'])
                    buffer_len += len(buffer[-1])
                    #print('iter buffer size:', sys.getsizeof(buffer), buffer_len)
                except StopIteration:
                    more_examples = False
                    break
            # Tokenize the buffered texts and concatenate them, joined by the BOS token.
            tokenized_inputs = self.tokenizer(buffer, truncation=False)['input_ids']
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            # Emit only full-length blocks; any trailing remainder is dropped.
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    yield torch.tensor(input_ids)
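
# Illustrative sketch only (commented out, not executed; the tokenizer is loaded
# further below, and toy_data / toy_ds are hypothetical names). Assuming a list of
# dicts with a 'text' field, every tensor the dataset yields has exactly seq_length
# token ids, with documents joined by the BOS token:
#
#   toy_data = [{'text': 'first document ...'}, {'text': 'second document ...'}]
#   toy_ds = ConstantLengthDataset(tokenizer, toy_data, seq_length=8)
#   for block in toy_ds:
#       assert block.shape == (8,)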

def setup_logging(project_name):
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.FileHandler(f"log/debug_{accelerator.process_index}.log"),
            logging.StreamHandler()])
    if accelerator.is_main_process:  # we only want to set up logging once
        wandb.init(project=project_name, config=args)
        run_name = wandb.run.name
        tb_writer = SummaryWriter()
        tb_writer.add_hparams(vars(args), {'0': 0})
        logger.setLevel(logging.INFO)
        datasets.utils.logging.set_verbosity_info()
        transformers.utils.logging.set_verbosity_info()
    else:
        tb_writer = None
        run_name = ''
        logger.setLevel(logging.ERROR)
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    return logger, tb_writer, run_name


def create_dataloaders(dataset_name, args):
    #@ds_kwargs = {"streaming": True, "chunksize": 40 << 20, "error_bad_chunk": False}
    ##train_data = load_dataset(dataset_name+'-train', split='train', **ds_kwargs)
    #@train_data = load_dataset(dataset_name+'-train', split='train', data_files="*.json.gz", **ds_kwargs)
    ds_kwargs = {"streaming": True, "chunksize": 40 << 20}
    #train_data = load_dataset('text', data_files={'train': ["wiki_mrph.txt"]},
    #                          split="train[:90%]", **ds_kwargs)
    train_data = load_dataset('text', data_files={'train': ["../ja-test-data/wiki_mrph_split_aa"]},
                              split='train', **ds_kwargs)
    print(train_data)
    #valid_data = load_dataset('text', data_files={'train': ["wiki_mrph.txt"]},
    #                          split="train[-10%:]", **ds_kwargs)
    valid_data = load_dataset('text', data_files={'train': ["../ja-test-data/wiki_mrph_split_ab"]},
                              split='train', **ds_kwargs)
    print(valid_data)
    #train_data = chunked((x for x in dataset), 1000)
    #valid_data = chunked((x for x in dataset), 1000)
    train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
    #@valid_data = load_dataset(dataset_name+'-valid', split="train", **ds_kwargs)
    train_dataset = ConstantLengthDataset(tokenizer, train_data, seq_length=args.seq_length)
    valid_dataset = ConstantLengthDataset(tokenizer, valid_data, seq_length=args.seq_length)
    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
    eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
    return train_dataloader, eval_dataloader


def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
    # Apply weight decay to all parameters except biases and LayerNorm weights.
    params_with_wd, params_without_wd = [], []
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay):
            params_without_wd.append(p)
        else:
            params_with_wd.append(p)
    return [{'params': params_with_wd, 'weight_decay': args.weight_decay},
            {'params': params_without_wd, 'weight_decay': 0.0}]


def log_metrics(step, metrics):
    logger.info(f"Step {step}: {metrics}")
    if accelerator.is_main_process:
        wandb.log(metrics)
        [tb_writer.add_scalar(k, v, step) for k, v in metrics.items()]


def evaluate(args):
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(batch, labels=batch)
        # Repeat the batch-mean loss per sample so gather() weights every sample equally across processes.
        loss = outputs.loss.repeat(args.valid_batch_size)
        losses.append(accelerator.gather(loss))
        if args.max_eval_steps > 0 and step >= args.max_eval_steps:
            break
    loss = torch.mean(torch.cat(losses))
    try:
        perplexity = torch.exp(loss)
    except OverflowError:
        perplexity = float("inf")
    return loss.item(), perplexity.item()


# Accelerator
accelerator = Accelerator(dispatch_batches=True)
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}

# Hyperparameters
#project_name = 'transformersbook/codeparrot'
#project_name = 'team-nave/codeparrot'
project_name = 'team-nave/ja-test-001'
#dataset_name = '../codeparrot'
#dataset_name = '../bookdata/codeparrot'
#dataset_name = 'team-nave/codeparrot'
dataset_name = '../../bookdata/codeparrot'
config = {
    #"train_batch_size": 2,
    #"valid_batch_size": 2,
    "train_batch_size": 16,
    "valid_batch_size": 8,
    "weight_decay": 0.1,
    "shuffle_buffer": 1_000,
    "learning_rate": 2e-4,
    "lr_scheduler_type": "cosine",
    "num_warmup_steps": 750,
    #"gradient_accumulation_steps": 16,
    "gradient_accumulation_steps": 6,
    #"gradient_checkpointing": True,
    #"optim": OptimizerNames.ADAMW_BNB,
    "max_train_steps": 50_000,
    "max_eval_steps": -1,
    "seq_length": 1024,
    "seed": 1,
    "save_checkpoint_steps": 1000,
    #"save_checkpoint_steps": 50_000,
}
args = Namespace(**config, **acc_state)
samples_per_step = accelerator.state.num_processes * args.train_batch_size
set_seed(args.seed)

# Logging
logger, tb_writer, run_name = setup_logging(project_name.split("/")[1])
logger.info(accelerator.state)

# Load model and tokenizer
if accelerator.is_main_process:
    hf_repo = Repository("./", clone_from=project_name, revision=run_name,
                         use_auth_token="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")
model = GPT2LMHeadModel.from_pretrained("./", gradient_checkpointing=True)
# Model Parallel
#torch.set_default_tensor_type('torch.cuda.FloatTensor')
#device_map = {
#    0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
#    1: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
#}
#model.parallelize(device_map)
#
#tokenizer = AutoTokenizer.from_pretrained("./")
tokenizer = AutoTokenizer.from_pretrained("./new_tokenizer/")

# Load dataset and dataloader
train_dataloader, eval_dataloader = create_dataloaders(dataset_name, args)
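
# Optional sanity check (a minimal illustrative sketch; effective_batch_size is a new
# name not used elsewhere in this script): with gradient accumulation, one optimizer
# update consumes train_batch_size * num_processes * gradient_accumulation_steps sequences.
effective_batch_size = samples_per_step * args.gradient_accumulation_steps
logger.info(f"Effective batch size per optimizer step: {effective_batch_size} sequences "
            f"({effective_batch_size * args.seq_length} tokens)")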

# Prepare the optimizer and learning rate scheduler
optimizer = AdamW(get_grouped_params(model, args), lr=args.learning_rate)
lr_scheduler = get_scheduler(name=args.lr_scheduler_type, optimizer=optimizer,
                             num_warmup_steps=args.num_warmup_steps,
                             num_training_steps=args.max_train_steps)


def get_lr():
    return optimizer.param_groups[0]['lr']


if accelerator.is_main_process:
    print(summary(model))
#print('Number of train_dataloader:', len(train_dataloader))

# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader)
#print('Number of train_dataloader:', len(train_dataloader))

# Train model
model.train()
completed_steps = 0
for step, batch in enumerate(train_dataloader, start=1):
    print('memory size of batch:', sys.getsizeof(batch), batch.size(), batch.numel())
    loss = model(batch, labels=batch, use_cache=False).loss
    log_metrics(step, {'lr': get_lr(), 'samples': step * samples_per_step,
                       'steps': completed_steps, 'loss/train': loss.item()})
    # Accumulate gradients over several steps before each optimizer update.
    loss = loss / args.gradient_accumulation_steps
    accelerator.backward(loss)
    if step % args.gradient_accumulation_steps == 0:
        accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        completed_steps += 1
    if step % args.save_checkpoint_steps == 0:
        logger.info('Evaluating and saving model checkpoint')
        eval_loss, perplexity = evaluate(args)
        log_metrics(step, {'loss/eval': eval_loss, 'perplexity': perplexity})
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        if accelerator.is_main_process:
            unwrapped_model.save_pretrained("./")
            hf_repo.push_to_hub(commit_message=f'step {step}')
        model.train()
    if completed_steps >= args.max_train_steps:
        break

# Evaluate and save the last checkpoint
logger.info('Evaluating and saving model after training')
eval_loss, perplexity = evaluate(args)
log_metrics(step, {'loss/eval': eval_loss, 'perplexity': perplexity})
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
if accelerator.is_main_process:
    unwrapped_model.save_pretrained("./")
    hf_repo.push_to_hub(commit_message='final model')
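
# Illustrative sketch only (commented out, run separately after training; assumes the
# final checkpoint is in "./" and the tokenizer in "./new_tokenizer/", as saved above):
#
#   from transformers import pipeline
#   generator = pipeline('text-generation', model='./', tokenizer='./new_tokenizer/')
#   print(generator('Example prompt', max_new_tokens=32))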