nroggendorff committed on
Commit
9547c62
1 Parent(s): 089175b

swap for sftc

Browse files

wdym trainingargs is deprecated :(

Files changed (1) hide show
  1. train.py +9 -3
train.py CHANGED
@@ -4,8 +4,9 @@ import torch
4
  import trl
5
  from transformers import (
6
  AutoTokenizer, LlamaConfig, AutoModelForCausalLM, LlamaForCausalLM,
7
- TrainingArguments, PreTrainedTokenizerFast, AdamW, get_cosine_schedule_with_warmup
8
  )
 
9
  from datasets import load_dataset, Dataset
10
  from tokenizers import ByteLevelBPETokenizer
11
  from huggingface_hub import HfApi
@@ -126,7 +127,7 @@ def create_model(tokenizer):
126
  return LlamaForCausalLM(config)
127
 
128
  def train_model(model, tokenizer, dataset, push_to_hub, is_instructional):
129
- args = TrainingArguments(
130
  output_dir="model",
131
  num_train_epochs=Config.EPOCHS,
132
  per_device_train_batch_size=Config.BATCH_SIZE,
@@ -145,7 +146,12 @@ def train_model(model, tokenizer, dataset, push_to_hub, is_instructional):
145
  batched=True,
146
  remove_columns=dataset.column_names
147
  )
148
- trainer = trl.SFTTrainer(model=model, tokenizer=tokenizer, args=args, train_dataset=dataset)
 
 
 
 
 
149
  train_result = trainer.train()
150
 
151
  if push_to_hub:
 
4
  import trl
5
  from transformers import (
6
  AutoTokenizer, LlamaConfig, AutoModelForCausalLM, LlamaForCausalLM,
7
+ PreTrainedTokenizerFast, AdamW, get_cosine_schedule_with_warmup
8
  )
9
+ from trl import SFTConfig, SFTTrainer
10
  from datasets import load_dataset, Dataset
11
  from tokenizers import ByteLevelBPETokenizer
12
  from huggingface_hub import HfApi
 
127
  return LlamaForCausalLM(config)
128
 
129
  def train_model(model, tokenizer, dataset, push_to_hub, is_instructional):
130
+ config = SFTConfig(
131
  output_dir="model",
132
  num_train_epochs=Config.EPOCHS,
133
  per_device_train_batch_size=Config.BATCH_SIZE,
 
146
  batched=True,
147
  remove_columns=dataset.column_names
148
  )
149
+ trainer = SFTTrainer(
150
+ model=model,
151
+ tokenizer=tokenizer,
152
+ config=config,
153
+ train_dataset=dataset
154
+ )
155
  train_result = trainer.train()
156
 
157
  if push_to_hub: