#!/usr/bin/env python3

import os
import json
import torch
import sys

from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

from accel import SwitchingAccelerator
from modelops import hf_tok, save_all_models
from aux import log, CmdlineArgs
from data import do_list_in_batches
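
# The last four imports are project-local modules: accel provides the
# SwitchingAccelerator training loop, modelops the HF auth token and model-saving
# helpers, aux the logging/CLI utilities, and data the batching helpers
# (roles inferred from how they are used in this script, not from their source).
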
def _cmdline_args():
    description = """Train or tune decoder models"""

    result = CmdlineArgs(description,
                         pos_arg_list=["mdl_id", "save_location", "train_file"],
                         pos_arg_types=[str, str, str],
                         kw_arg_dict={"continue_training": False, "save_steps": 100, "lr": 1.5e-5,
                                      "batch_size": 1024, "nr_sents_per_gpu": 4, "log_steps": 1, "epochs": 4,
                                      "max_length": 3000})

    # refuse to overwrite an existing save location unless training is being continued:
    if not result.continue_training and os.path.exists(result.save_location):
        raise Exception(f"Save location '{result.save_location}' already exists, don't want to overwrite.")

    # nr_sents_per_gpu == 0 is treated as "no sub-batching": use the full batch at once
    if result.nr_sents_per_gpu == 0:
        result.nr_sents_per_gpu = result.batch_size

    return result
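
# A hypothetical invocation (paths and script name are placeholders), following
# the positional-then-keyword pattern that CmdlineArgs appears to parse:
#   ./trainllm.py models/llama3.2-1b models/newmdl train.json batch_size=16 epochs=1
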
def load_json_list(json_file):
    with open(json_file, "r") as f:
        data = json.load(f)
    return data
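
# Assumption: the training file holds a single JSON document containing a list
# of training examples; their exact structure is whatever SwitchingAccelerator expects.
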
def load_hf_model(mdl_id, accelerator=None):
    if accelerator is None:
        model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16)
    else:
        model = AutoModelForCausalLM.from_pretrained(mdl_id, token=hf_tok, torch_dtype=torch.bfloat16,
                                                     device_map=accelerator.device)
    return model
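
# With an accelerator given, device_map=accelerator.device loads the weights
# directly onto this process's device; without one (as in the main function
# below) the model is loaded onto CPU and presumably moved to the right devices
# later by the training code.
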
def load_hf_tokenizer(mdl_id):
    tokenizer = AutoTokenizer.from_pretrained(mdl_id, token=hf_tok)
    return tokenizer
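
# Note: many decoder-only checkpoints (e.g. Llama) ship without a pad token; if
# padding is needed downstream, a common workaround is
#   tokenizer.pad_token = tokenizer.eos_token
# (not applied here; the training code may handle this itself).
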
def _no_globals_main():
    accelerator = Accelerator()

    try:
        args = _cmdline_args()
        log(f"Num proc: {accelerator.num_processes}, proc ID: {accelerator.process_index}")

        log("loading model", accelerator=accelerator)
        mdl = load_hf_model(args.mdl_id)

        log("loading tokenizer", accelerator=accelerator)
        tok = load_hf_tokenizer(args.mdl_id)

        log("loading data", accelerator=accelerator, all_threads=True)
        train_set = load_json_list(args.train_file)

        log("training", accelerator=accelerator)
        acc_trainer = SwitchingAccelerator(train_set, args, mdl, tok, preinit_acc=accelerator)
        upd_model = acc_trainer.train()

        log("saving", accelerator=accelerator)
        save_all_models(args.save_location, upd_model, tok)
    except Exception:
        # in multi-process runs the interleaved stack traces are hard to read,
        # so only the main process re-raises and the others exit quietly:
        if accelerator.is_main_process:
            raise
if __name__ == "__main__":
    # debug presets: uncomment to hard-code CLI arguments for a local test run
    #sys.argv = "_ models/llama3.2-1b models/newmdl tmp.json".split()
    #sys.argv = "_ models/llama3.2-1b models/newmdl2 tmpx.json batch_size=16 nr_sents_per_gpu=1 log_steps=1 save_steps=2000 epochs=1".split()
    _no_globals_main()
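
# For multi-GPU runs this script would typically be started through HuggingFace
# Accelerate rather than plain Python, e.g. (script name is a placeholder):
#   accelerate launch trainllm.py <mdl_id> <save_location> <train_file>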