vishaljoshi24 committed on
Commit
bc05830
·
1 Parent(s): 731cc49

SFT with different LLM

Browse files
Files changed (1) hide show
  1. quickstart.py +4 -10
quickstart.py CHANGED
@@ -1,16 +1,10 @@
 
1
  from datasets import load_dataset
2
- from trl import GRPOTrainer
3
- import os
4
 
5
- dataset = load_dataset("trl-lib/tldr", split="train")
6
 
7
- # Dummy reward function: count the number of unique characters in the completions
8
- def reward_num_unique_chars(completions, **kwargs):
9
- return [len(set(c)) for c in completions]
10
-
11
- trainer = GRPOTrainer(
12
- model="openai-community/gpt2",
13
- reward_funcs=reward_num_unique_chars,
14
  train_dataset=dataset,
15
  )
16
  trainer.train()
 
1
+ from trl import SFTTrainer
2
  from datasets import load_dataset
 
 
3
 
4
+ dataset = load_dataset("trl-lib/Capybara", split="train")
5
 
6
+ trainer = SFTTrainer(
7
+ model="Qwen/Qwen2.5-0.5B",
 
 
 
 
 
8
  train_dataset=dataset,
9
  )
10
  trainer.train()