from config import HAS_CUDA, MODEL, DEVICE_MAP, TRAINING_PARAMS, LORA_TRAINING_PARAMS, GENERATION_PARAMS

import os
import gc

import torch
import transformers
import peft
import datasets
from contextlib import nullcontext


class Trainer:
    """Wraps base-model loading, LoRA adapter management, generation, and LoRA fine-tuning."""

    def __init__(self):
        self.model = None
        self.model_name = None
        self.lora_name = None
        self.loras = {}

        self.tokenizer = None
        self.trainer = None

    def unload_model(self):
        """Drop the current model and tokenizer and release GPU memory."""
        del self.model
        del self.tokenizer

        self.model = None
        self.model_name = None
        self.tokenizer = None

        if HAS_CUDA:
            with torch.no_grad():
                torch.cuda.empty_cache()

        gc.collect()

    def load_model(self, model_name, force=False, **kwargs):
        """Load a causal LM in 8-bit, replacing any previously loaded model."""
        assert model_name is not None

        if model_name == self.model_name and not force:
            return

        if self.model is not None:
            self.unload_model()

        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=DEVICE_MAP,
            load_in_8bit=True,
            torch_dtype=torch.float16,
        )

        # The decapoda-research LLaMA checkpoints need the dedicated LlamaTokenizer.
        if model_name.startswith('decapoda-research/llama'):
            self.tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name)
        else:
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

        self.tokenizer.pad_token_id = 0
        self.model_name = model_name

    def load_lora(self, lora_name, replace_model=True):
        """Attach a LoRA adapter, loading its base model first if it differs from the current one."""
        assert self.model is not None
        assert lora_name is not None

        if lora_name == self.lora_name:
            return

        # Adapter was already loaded once: just switch back to it.
        if lora_name in self.loras:
            self.lora_name = lora_name
            self.model.set_adapter(lora_name)
            return

        peft_config = peft.PeftConfig.from_pretrained(lora_name)
        if not replace_model:
            assert peft_config.base_model_name_or_path == self.model_name

        if peft_config.base_model_name_or_path != self.model_name:
            self.load_model(peft_config.base_model_name_or_path)
            self.loras = {}

        assert self.model_name is not None
        assert self.model is not None

        if hasattr(self.model, 'load_adapter'):
            self.model.load_adapter(lora_name, adapter_name=lora_name)
        else:
            self.model = peft.PeftModel.from_pretrained(self.model, lora_name, adapter_name=lora_name)

        self.model.set_adapter(lora_name)
        if self.model_name.startswith('cerebras'):
            self.model.half()

        self.lora_name = lora_name
        self.loras[lora_name] = True

    def unload_lora(self):
        self.lora_name = None

    def generate(self, prompt, **kwargs):
        """Generate a completion for `prompt`, bypassing the adapter when none is selected."""
        assert self.model is not None
        assert self.model_name is not None
        assert self.tokenizer is not None

        kwargs = {**GENERATION_PARAMS, **kwargs}

        inputs = self.tokenizer(str(prompt), return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model.device)

        if self.model.config.pad_token_id is None:
            kwargs['pad_token_id'] = self.model.config.eos_token_id

        # Sampling and beam search are mutually exclusive here.
        if kwargs['do_sample']:
            del kwargs['num_beams']

        generation_config = transformers.GenerationConfig(
            use_cache=False,
            **kwargs,
        )

        disable_lora = nullcontext()
        if self.lora_name is None and hasattr(self.model, 'disable_adapter'):
            disable_lora = self.model.disable_adapter()

        with torch.no_grad(), disable_lora:
            output = self.model.generate(
                input_ids=input_ids,
                attention_mask=torch.ones_like(input_ids),
                generation_config=generation_config,
            )[0].to(self.model.device)

        return self.tokenizer.decode(output, skip_special_tokens=True).strip()

    def tokenize_sample(self, item, max_seq_length, add_eos_token=True):
        """Tokenize one training sample, leaving room to append an explicit EOS token."""
        assert self.tokenizer is not None

        result = self.tokenizer(
            item["text"],
            truncation=True,
            max_length=max_seq_length,
            padding="max_length",
        )

        # Drop the final token so an EOS can be appended below without exceeding max_seq_length.
        result = {
            "input_ids": result["input_ids"][:-1],
            "attention_mask": result["attention_mask"][:-1],
        }

        if (
            result["input_ids"][-1] != self.tokenizer.eos_token_id
            and len(result["input_ids"]) < max_seq_length
            and add_eos_token
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        return result

    def tokenize_training_text(self, training_text, max_seq_length, separator="\n\n\n", **kwargs):
        """Split raw training text on `separator`, then shuffle and tokenize the samples."""
        samples = training_text.split(separator)
        samples = [x.strip() for x in samples]

        def to_dict(text):
            return {'text': text}

        samples = [to_dict(x) for x in samples]

        training_dataset = datasets.Dataset.from_list(samples)
        training_dataset = training_dataset.shuffle().map(
            lambda x: self.tokenize_sample(x, max_seq_length),
            batched=False
        )

        return training_dataset

    def train(self, training_text=None, new_peft_model_name=None, **kwargs):
        """Fine-tune a new LoRA adapter on `training_text` and save it under lora/."""
        assert self.model is not None
        assert self.model_name is not None
        assert self.tokenizer is not None

        kwargs = {**TRAINING_PARAMS, **LORA_TRAINING_PARAMS, **kwargs}

        self.lora_name = None
        self.loras = {}

        train_dataset = self.tokenize_training_text(training_text, **kwargs)

        # If an adapter is already attached, reload the plain base model before training a new one.
        if hasattr(self.model, 'disable_adapter'):
            self.load_model(self.model_name, force=True)

        self.model = peft.prepare_model_for_int8_training(self.model)
        self.model = peft.get_peft_model(self.model, peft.LoraConfig(
            r=kwargs['lora_r'],
            lora_alpha=kwargs['lora_alpha'],
            lora_dropout=kwargs['lora_dropout'],
            bias="none",
            task_type="CAUSAL_LM",
        ))

        if not os.path.exists('lora'):
            os.makedirs('lora')

        sanitized_model_name = self.model_name.replace('/', '_').replace('.', '_')
        output_dir = f"lora/{sanitized_model_name}_{new_peft_model_name}"

        training_args = transformers.TrainingArguments(
            per_device_train_batch_size=kwargs['micro_batch_size'],
            gradient_accumulation_steps=kwargs['gradient_accumulation_steps'],
            num_train_epochs=kwargs['epochs'],
            learning_rate=kwargs['learning_rate'],
            fp16=True,
            optim='adamw_torch',
            logging_steps=20,
            save_total_limit=3,
            output_dir=output_dir,
        )

        # _trainer = self
        # class LoggingCallback(transformers.TrainerCallback):
        #     def on_log(self, args, state, control, logs=None, **kwargs):
        #         _trainer.log += json.dumps(logs) + '\n'

        self.trainer = transformers.Trainer(
            model=self.model,
            train_dataset=train_dataset,
            args=training_args,
            data_collator=transformers.DataCollatorForLanguageModeling(
                self.tokenizer,
                mlm=False,
            ),
            # callbacks=[LoggingCallback()]
        )

        self.model.config.use_cache = False
        result = self.trainer.train(resume_from_checkpoint=False)
        self.model.save_pretrained(output_dir)

        return result


if __name__ == '__main__':
    t = Trainer()
    t.load_model(MODEL)

    prompt = "Human: How is cheese made?\n\nAssistant:"
    print(t.generate(prompt))

    t.load_lora('lora/melon-mango-orange')
    print(t.generate(prompt))

    t.unload_lora()
    print(t.generate(prompt))