"""Fine-tune Falcon-7B on PersonaChat with 4-bit quantization and LoRA.

The script loads the base model quantized to 4-bit (NF4), attaches a LoRA
adapter configuration, renders each dataset row into a
personality / history / response prompt, and trains with TRL's SFTTrainer.
"""

import einops  # noqa: F401  # presumably needed by Falcon's remote modeling code -- TODO confirm
import torch
from datasets import load_dataset
from peft import LoraConfig
from peft.tuners.lora import LoraLayer  # noqa: F401  # unused in this script; kept from the original
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer

# Prompt layout for one training sample. The {response} placeholder is
# required: str.format() silently ignores unused keyword arguments, so
# without it the target text would never reach the model.
# NOTE(review): the line breaks between sections are a reconstruction --
# confirm against the intended prompt format.
template = """### Personality:
{personality}

### History:
{history}

### Response:
{response}"""

model_name = "tiiuae/falcon-7b"
dataset_name = "bavard/personachat_truecased"


def create_and_prepare_model():
    """Build the quantized base model, the LoRA config and the tokenizer.

    Returns:
        A ``(model, peft_config, tokenizer)`` tuple: the 4-bit quantized
        causal LM, the ``LoraConfig`` targeting ``query_key_value``, and the
        tokenizer with its pad token set to the EOS token.
    """
    compute_dtype = torch.float16

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
    )

    # Alternative placement: device_map={"": 0}.
    device_map = "auto"

    # Load the model exactly once. Calling from_pretrained again without
    # quantization_config would replace the 4-bit model with an
    # un-quantized one and load the checkpoint a second time.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map=device_map,
        trust_remote_code=True,
    )

    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "query_key_value",
        ],
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Reuse the EOS token as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    return model, peft_config, tokenizer


training_arguments = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    optim="paged_adamw_32bit",
    save_steps=1000,
    logging_steps=10,
    learning_rate=2e-4,
    fp16=True,
    max_grad_norm=0.3,
    max_steps=10000,
    warmup_ratio=0.03,
    group_by_length=False,
    lr_scheduler_type="constant",
)

dataset = load_dataset(dataset_name, split="train")

model, peft_config, tokenizer = create_and_prepare_model()
model.config.use_cache = False


def formatting_func(example):
    """Render one dataset row into a single training string.

    Args:
        example: A row providing ``"personality"`` and ``"history"``
            (iterables of strings, joined with newlines) and
            ``"candidates"`` (a sequence whose last item is used as the
            response -- presumably the gold reply; verify against the
            dataset card).

    Returns:
        The filled-in ``template`` string.
    """
    return template.format(
        personality="\n".join(example["personality"]),
        history="\n".join(example["history"]),
        response=example["candidates"][-1],
    )


trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    max_seq_length=512,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=True,
    formatting_func=formatting_func,
)

trainer.train()