Fine-tuning a language model via PPO consists of roughly three steps:
1. Rollout: the language model generates a response (or continuation) based on a query, which could for example be the start of a sentence.
2. Evaluation: the query and response are evaluated with a function, a model, human feedback, or some combination of them. The important thing is that this process yields a scalar reward for each query/response pair.
3. Optimization: the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences, both with the model being trained and with a reference model (usually the pre-trained model before fine-tuning). The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses do not deviate too far from the reference language model. The active model is then trained with PPO.
The following code illustrates the steps above.
# 0. imports
import torch
from transformers import GPT2Tokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch

# 1. load a pretrained model and create a frozen reference copy of it
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# 2. initialize trainer
ppo_config = {'batch_size': 1, 'forward_batch_size': 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)

# 3. encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

# 4. generate a model response
response_tensor = respond_to_batch(model, query_tensor)
response_txt = tokenizer.decode(response_tensor[0, :])

# 5. define a reward for the response
# (this could be any scalar reward, such as human feedback or the output of another model)
reward = [torch.tensor(1.0)]

# 6. train the model with PPO on this query/response pair
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
In general, you would run steps 3-6 in a loop over many diverse queries. You can find more realistic examples in the examples section.
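As a minimal sketch of such a loop, continuing from the snippet above: the list of prompts (queries) and the compute_reward function below are placeholders for your own data and reward signal, not part of the library.

# Sketch: run steps 3-6 over many queries.
# `queries` and `compute_reward` are placeholders for your own prompts and reward function.
queries = ["This morning I went to the ", "Yesterday evening we drove to the "]

def compute_reward(query, response):
    # replace with human feedback, a reward model, or any other scoring function
    return 1.0

for query_txt in queries:
    query_tensor = tokenizer.encode(query_txt, return_tensors="pt")
    response_tensor = respond_to_batch(model, query_tensor)
    response_txt = tokenizer.decode(response_tensor[0, :])
    reward = [torch.tensor(compute_reward(query_txt, response_txt))]
    train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)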
After training an AutoModelForCausalLMWithValueHead, you can directly use the model in transformers.
# Assume `model` was trained with `PPOTrainer` as an `AutoModelForCausalLMWithValueHead`

# push the model to the Hub
model.push_to_hub("my-fine-tuned-model-ppo")

# or save it locally
model.save_pretrained("my-fine-tuned-model-ppo")

# load the model back from the Hub as a regular causal LM
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("my-fine-tuned-model-ppo")
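Once reloaded, the model can be used for generation like any other causal language model in transformers. A minimal sketch, assuming the original gpt2 tokenizer is still the right one for the fine-tuned model:

# Sketch: generate with the reloaded model
# (assumes the original 'gpt2' tokenizer still matches the fine-tuned model)
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
inputs = tokenizer("This morning I went to the ", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0]))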
You can also load your model with AutoModelForCausalLMWithValueHead if you want to use the value head, for example to continue training.
from trl import AutoModelForCausalLMWithValueHead
model = AutoModelForCausalLMWithValueHead.from_pretrained("my-fine-tuned-model-ppo")
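A minimal sketch of continuing PPO training with the reloaded value-head model, assuming the config, tokenizer, and create_reference_model import from the quickstart snippet above are still in scope:

# Sketch: resume PPO training with the reloaded value-head model
# (assumes `config`, `tokenizer`, and `create_reference_model` from the snippet above)
model_ref = create_reference_model(model)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
# ... then run the rollout / reward / ppo_trainer.step() loop as before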