from config import HAS_CUDA, MODEL, DEVICE_MAP, TRAINING_PARAMS, LORA_TRAINING_PARAMS, GENERATION_PARAMS

import os
import gc

import torch
import transformers
import peft
import datasets

from contextlib import nullcontext
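
# Thin wrapper around transformers + peft that handles model loading,
# LoRA adapter management, generation, and LoRA fine-tuning for a causal LM.
# MODEL, DEVICE_MAP, and the *_PARAMS dicts come from the local config module.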
class Trainer:
    def __init__(self):
        self.model = None
        self.model_name = None
        self.lora_name = None
        self.loras = {}
        self.tokenizer = None
        self.trainer = None
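
    # Drop references to the current model and tokenizer and free GPU memory.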
    def unload_model(self):
        del self.model
        del self.tokenizer

        self.model = None
        self.model_name = None
        self.tokenizer = None

        if HAS_CUDA:
            with torch.no_grad():
                torch.cuda.empty_cache()

        gc.collect()
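
    # Load a causal LM in 8-bit (plus its tokenizer), replacing any model that
    # is already loaded. force=True reloads even when the name is unchanged.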
    def load_model(self, model_name, force=False, **kwargs):
        assert model_name is not None

        if model_name == self.model_name and not force:
            return

        if self.model is not None:
            self.unload_model()

        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=DEVICE_MAP,
            load_in_8bit=True,
            torch_dtype=torch.float16,
        )

        if model_name.startswith('decapoda-research/llama'):
            self.tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name)
        else:
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

        self.tokenizer.pad_token_id = 0
        self.model_name = model_name
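
    # Attach a LoRA adapter by name or path. Loads the adapter's base model
    # first if it differs from the current one (unless replace_model is False),
    # and caches loaded adapters so switching back is a cheap set_adapter call.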
    def load_lora(self, lora_name, replace_model=True):
        assert self.model is not None
        assert lora_name is not None

        if lora_name == self.lora_name:
            return

        if lora_name in self.loras:
            self.lora_name = lora_name
            self.model.set_adapter(lora_name)
            return

        peft_config = peft.PeftConfig.from_pretrained(lora_name)
        if not replace_model:
            assert peft_config.base_model_name_or_path == self.model_name

        if peft_config.base_model_name_or_path != self.model_name:
            self.load_model(peft_config.base_model_name_or_path)
            self.loras = {}

        assert self.model_name is not None
        assert self.model is not None

        if hasattr(self.model, 'load_adapter'):
            self.model.load_adapter(lora_name, adapter_name=lora_name)
        else:
            self.model = peft.PeftModel.from_pretrained(self.model, lora_name, adapter_name=lora_name)

        self.model.set_adapter(lora_name)
        if self.model_name.startswith('cerebras'):
            self.model.half()

        self.lora_name = lora_name
        self.loras[lora_name] = True
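
    # Deactivate the current LoRA for generation. The adapter weights stay
    # loaded; generate() wraps inference in disable_adapter() while
    # lora_name is None.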
    def unload_lora(self):
        self.lora_name = None
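
    # Generate a completion for `prompt` using GENERATION_PARAMS merged with
    # any overrides passed in `kwargs`.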
    def generate(self, prompt, **kwargs):
        assert self.model is not None
        assert self.model_name is not None
        assert self.tokenizer is not None

        kwargs = {**GENERATION_PARAMS, **kwargs}

        inputs = self.tokenizer(str(prompt), return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model.device)

        if self.model.config.pad_token_id is None:
            kwargs['pad_token_id'] = self.model.config.eos_token_id

        # Beam search and sampling don't mix; drop num_beams when sampling.
        if kwargs.get('do_sample'):
            kwargs.pop('num_beams', None)

        generation_config = transformers.GenerationConfig(
            use_cache=False,
            **kwargs,
        )

        disable_lora = nullcontext()
        if self.lora_name is None and hasattr(self.model, 'disable_adapter'):
            disable_lora = self.model.disable_adapter()

        with torch.no_grad(), disable_lora:
            output = self.model.generate(
                input_ids=input_ids,
                attention_mask=torch.ones_like(input_ids),
                generation_config=generation_config,
            )[0].to(self.model.device)

        return self.tokenizer.decode(output, skip_special_tokens=True).strip()
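
    # Tokenize one sample to max_seq_length (padded), drop the final token,
    # and append an EOS token when there is room for it.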
    def tokenize_sample(self, item, max_seq_length, add_eos_token=True):
        assert self.tokenizer is not None

        result = self.tokenizer(
            item["text"],
            truncation=True,
            max_length=max_seq_length,
            padding="max_length",
        )

        result = {
            "input_ids": result["input_ids"][:-1],
            "attention_mask": result["attention_mask"][:-1],
        }

        if (
            result["input_ids"][-1] != self.tokenizer.eos_token_id
            and len(result["input_ids"]) < max_seq_length
            and add_eos_token
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        return result
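
    # Split raw training text into samples on `separator`, then build a
    # shuffled, tokenized datasets.Dataset.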
    def tokenize_training_text(self, training_text, max_seq_length, separator="\n\n\n", **kwargs):
        samples = training_text.split(separator)
        samples = [x.strip() for x in samples]

        def to_dict(text):
            return {'text': text}

        samples = [to_dict(x) for x in samples]
        training_dataset = datasets.Dataset.from_list(samples)
        training_dataset = training_dataset.shuffle().map(
            lambda x: self.tokenize_sample(x, max_seq_length),
            batched=False,
        )

        return training_dataset
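
    # Fine-tune a fresh LoRA adapter on `training_text` and save it under
    # lora/<model>_<new_peft_model_name>. Hyperparameters come from
    # TRAINING_PARAMS / LORA_TRAINING_PARAMS, with `kwargs` as overrides.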
    def train(self, training_text=None, new_peft_model_name=None, **kwargs):
        assert self.model is not None
        assert self.model_name is not None
        assert self.tokenizer is not None

        kwargs = {**TRAINING_PARAMS, **LORA_TRAINING_PARAMS, **kwargs}

        self.lora_name = None
        self.loras = {}

        train_dataset = self.tokenize_training_text(training_text, **kwargs)

        if hasattr(self.model, 'disable_adapter'):
            self.load_model(self.model_name, force=True)

        self.model = peft.prepare_model_for_int8_training(self.model)
        self.model = peft.get_peft_model(self.model, peft.LoraConfig(
            r=kwargs['lora_r'],
            lora_alpha=kwargs['lora_alpha'],
            lora_dropout=kwargs['lora_dropout'],
            bias="none",
            task_type="CAUSAL_LM",
        ))

        if not os.path.exists('lora'):
            os.makedirs('lora')

        sanitized_model_name = self.model_name.replace('/', '_').replace('.', '_')
        output_dir = f"lora/{sanitized_model_name}_{new_peft_model_name}"

        training_args = transformers.TrainingArguments(
            per_device_train_batch_size=kwargs['micro_batch_size'],
            gradient_accumulation_steps=kwargs['gradient_accumulation_steps'],
            num_train_epochs=kwargs['epochs'],
            learning_rate=kwargs['learning_rate'],
            fp16=True,
            optim='adamw_torch',
            logging_steps=20,
            save_total_limit=3,
            output_dir=output_dir,
        )

        # _trainer = self
        # class LoggingCallback(transformers.TrainerCallback):
        #     def on_log(self, args, state, control, logs=None, **kwargs):
        #         _trainer.log += json.dumps(logs) + '\n'

        self.trainer = transformers.Trainer(
            model=self.model,
            train_dataset=train_dataset,
            args=training_args,
            data_collator=transformers.DataCollatorForLanguageModeling(
                self.tokenizer,
                mlm=False,
            ),
            # callbacks=[LoggingCallback()]
        )

        self.model.config.use_cache = False
        result = self.trainer.train(resume_from_checkpoint=False)
        self.model.save_pretrained(output_dir)

        return result
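
# Quick smoke test: generate with the bare model, with a LoRA attached, and
# again after deactivating the LoRA. Assumes the adapter directory
# 'lora/melon-mango-orange' exists locally.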
if __name__ == '__main__':
    t = Trainer()
    t.load_model(MODEL)

    prompt = "Human: How is cheese made?\n\nAssistant:"
    print(t.generate(prompt))

    t.load_lora('lora/melon-mango-orange')
    print(t.generate(prompt))

    t.unload_lora()
    print(t.generate(prompt))