chatlawv1 / trlx /examples /ppo_sentiments_t5.py
import json
import os
import sys
from typing import Dict, List
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, pipeline
import trlx
from trlx.data.configs import (
ModelConfig,
OptimizerConfig,
SchedulerConfig,
TokenizerConfig,
TrainConfig,
TRLConfig,
)
from trlx.models.modeling_ppo import PPOConfig
def get_positive_score(scores):
"Extract value associated with a positive sentiment from pipeline's output"
return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]
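# Example of what the sentiment pipeline returns per sample when called with `top_k=2`
# (scores are illustrative): [{"label": "NEGATIVE", "score": 0.04}, {"label": "POSITIVE", "score": 0.96}]
# get_positive_score would return 0.96 for that sample.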
default_config = TRLConfig(
train=TrainConfig(
seq_length=128,
epochs=100,
total_steps=100000,
batch_size=12,
checkpoint_interval=10000,
eval_interval=100,
pipeline="PromptPipeline",
trainer="AcceleratePPOTrainer",
save_best=False,
),
model=ModelConfig(
model_path="lvwerra/t5-imdb",
num_layers_unfrozen=-1,
model_arch_type="seq2seq",
),
tokenizer=TokenizerConfig(
tokenizer_path="lvwerra/t5-imdb",
padding_side="right",
truncation_side="right",
),
optimizer=OptimizerConfig(
name="adamw",
kwargs={
"lr": 5.0e-5,
"betas": [0.9, 0.999],
"eps": 1.0e-8,
"weight_decay": 1.0e-6,
},
),
scheduler=SchedulerConfig(
name="cosine_annealing",
kwargs={
"T_max": 100000,
"eta_min": 5.0e-5,
},
),
method=PPOConfig(
name="PPOConfig",
num_rollouts=128,
chunk_size=12,
ppo_epochs=4,
init_kl_coef=0.05,
target=6,
horizon=10000,
gamma=0.99,
lam=0.95,
cliprange=0.2,
cliprange_value=0.2,
vf_coef=1,
scale_reward=None,
ref_mean=None,
ref_std=None,
cliprange_reward=10,
gen_kwargs={
"max_new_tokens": 50,
"do_sample": True,
"top_k": 0,
"top_p": 1,
"eos_token_id": -1,
},
),
)
class LengthSampler:
"""
Samples a length
"""
def __init__(self, min_value, max_value):
self.values = list(range(min_value, max_value))
self.rng = np.random.default_rng(seed=2023)
def __call__(self):
return self.rng.choice(self.values)
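# Example: LengthSampler(2, 8)() returns an integer drawn uniformly from {2, ..., 7}.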
def main(hparams={}):
config = TRLConfig.update(default_config, hparams)
    # Used below as the PPO reward function: returns one positive-sentiment score per sample
    def metric_fn(samples: List[str], **kwargs) -> List[float]:
        sentiments = list(map(get_positive_score, sentiment_fn(samples)))
        return sentiments
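    # Reward model: a DistilBERT classifier fine-tuned on IMDB sentiment. Only the local
    # rank-0 process places it on GPU 0 (device=0); all other ranks run it on CPU (device=-1).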
sentiment_fn = pipeline(
"sentiment-analysis",
"lvwerra/distilbert-imdb",
top_k=2,
truncation=True,
batch_size=256,
device=0 if int(os.environ.get("LOCAL_RANK", 0)) == 0 else -1,
)
tokenizer = AutoTokenizer.from_pretrained("lvwerra/t5-imdb")
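    # Prompts are short IMDB review prefixes (2 to 7 tokens by default) ending with the EOS
    # token; the T5 policy generates a continuation, which the sentiment pipeline then scores.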
    def build_imdb_dataset(tokenizer, input_min_text_length=2, input_max_text_length=8, split="train"):
# load imdb with datasets
        ds = load_dataset("imdb", split=split)
ds = ds.rename_columns({"text": "review"})
ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)
input_size = LengthSampler(input_min_text_length, input_max_text_length)
def tokenize(sample):
sample["review"] = sample["review"].replace("/>br", "")
sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()] + [tokenizer.eos_token_id]
sample["query"] = tokenizer.decode(sample["input_ids"])
return sample
ds = ds.map(tokenize, batched=False)
ds.set_format(type="torch")
return ds
    dataset = build_imdb_dataset(tokenizer)
    prompts = dataset["query"]
    val_prompts = build_imdb_dataset(tokenizer, split="test")["query"][0:100]
trlx.train(
prompts=prompts,
eval_prompts=val_prompts,
reward_fn=metric_fn,
config=config,
)
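# Hyperparameter overrides can be passed as a JSON string in the first CLI argument and are
# merged into default_config by TRLConfig.update, e.g.:
#   python ppo_sentiments_t5.py '{"train": {"total_steps": 1000}}'
# (the exact nesting accepted by TRLConfig.update may differ; treat this invocation as illustrative)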
if __name__ == "__main__":
hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
main(hparams)