import datasets
import torch
from peft import LoraConfig, TaskType, get_peft_model
from peft.peft_model import PeftModel
from transformers import LlamaForCausalLM as ModelCls
from transformers import Trainer, TrainingArguments

# Load the base model
model_name = "TheBloke/Llama-2-7B-Chat-fp16"
model: ModelCls = ModelCls.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Wrap the base model with a PEFT (LoRA) adapter
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)
model: PeftModel = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Load the dataset (pre-tokenized JSON files)
data_files = {
    "train": "data/train.tokens.json.gz",
    "dev": "data/dev.tokens.json.gz",
}
dataset = datasets.load_dataset(
    "json",
    data_files=data_files,
    cache_dir="cache",
)

# Set up the training arguments
output_dir = "models/Llama-7B-TwAddr-LoRA"
train_args = TrainingArguments(
    output_dir,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    eval_accumulation_steps=2,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-4,
    save_total_limit=3,
    num_train_epochs=5,
    load_best_model_at_end=True,
    bf16=True,
)

# Start training
trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["dev"],
)
trainer.train()

# Save the trained model (only the LoRA adapter weights are written to output_dir)
trainer.save_model()
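
# --- Inference with the saved adapter ---
# A minimal sketch of how the adapter saved by trainer.save_model() could be
# loaded back onto the base model for generation. The tokenizer choice, the
# prompt, and the generation settings below are illustrative assumptions, not
# part of the training script above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Reload a fresh copy of the base model, then attach the trained LoRA adapter.
base_model = ModelCls.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
lora_model = PeftModel.from_pretrained(base_model, output_dir)
lora_model.eval()

prompt = "..."  # hypothetical prompt; replace with an actual input
inputs = tokenizer(prompt, return_tensors="pt").to(base_model.device)
with torch.no_grad():
    outputs = lora_model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# Alternatively, lora_model.merge_and_unload() folds the LoRA weights into the
# base model so it can be used or exported without the peft dependency.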