import os
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoModelForSequenceClassification
)
from datasets import Dataset, load_dataset
from transformers.trainer_callback import TrainerCallback
from trl import GRPOConfig, GRPOTrainer
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import wandb # wandbをインポート
# Constants
MODEL_NAME = "llm-jp/llm-jp-3-1.8b-instruct3"
models_save_path = r"/home/tetora/デスクトップ/llm_sft/model"
MAX_TOKEN = 4096
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def count_tokens(text):
"""テキストのトークン数をカウントする"""
return len(tokenizer.encode(text, add_special_tokens=False))
def load_wiki40b_qa_data():
"""Wiki40b QAデータセットを読み込み、5万件に制限する"""
dataset = load_dataset("U23-lab/wiki40b_qa_ja", split="train")
# データセットを制限し、ランダムにシャッフル
dataset = dataset.shuffle(seed=42).select(range(5000))
print(f"データセットを5千件に制限しました。現在のサイズ: {len(dataset)}")
# 最大トークン数を確認
max_token_count = 0
for item in tqdm(dataset, desc="Checking token counts"):
qa_text = f"質問: {item['question']}\n回答: {item['answer']}"
token_count = count_tokens(qa_text)
max_token_count = max(max_token_count, token_count)
print(f"最大トークン数: {max_token_count}")
return dataset
class SaveModelAndTokenizerCallback(TrainerCallback):
def __init__(self, model, tokenizer, output_dir):
self.model = model
self.tokenizer = tokenizer
self.output_dir = output_dir
def on_epoch_end(self, args, state, control, **kwargs):
checkpoint_dir = os.path.join(
self.output_dir, f"checkpoint-{state.global_step}"
)
os.makedirs(checkpoint_dir, exist_ok=True)
self.tokenizer.save_pretrained(checkpoint_dir)
# 事前に埋め込みモデルをロード(従来の報酬関数用)
embedding_model = SentenceTransformer('sbintuitions/modernbert-ja-310m')
# 倫理性に関する報酬関数(swdq/ethics_classifier を利用)
def ethics_reward(completions, prompts=None, **kwargs):
"""swdq/ethics_classifierに基づいて倫理性評価を行う報酬関数。
"evil"なら報酬(+2.0)、"justice"ならペナルティ(-2.0)を与えます。"""
# モデルパスとラベル辞書の設定
model_path = "swdq/ethics_classifier"
label_dict = {"evil": 0, "justice": 1}
# モデルとトークナイザーのロード
tokenizer_cls = AutoTokenizer.from_pretrained(model_path)
model_cls = AutoModelForSequenceClassification.from_pretrained(model_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_cls.to(device)
model_cls.eval()
# テキストの処理
inputs = tokenizer_cls(completions, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(device) for k, v in inputs.items()}
# 推論
with torch.no_grad():
outputs = model_cls(**inputs)
# ロジットをソフトマックスに変換
probs = outputs.logits.softmax(dim=-1)
rewards = []
for prob in probs:
predicted_class = torch.argmax(prob).item()
if predicted_class == label_dict["evil"]:
reward = 2.0
else:
reward = -2.0
rewards.append(reward)
return rewards
# 文章の長さに関する報酬関数
def length_reward(completions, prompts=None, **kwargs):
"""回答の長さに基づく報酬関数 (100~500文字の範囲を促進、範囲外はペナルティ)"""
rewards = []
for completion in completions:
completion_length = len(completion)
# 100~500文字の範囲内なら最高報酬
if 100 <= completion_length <= 500:
reward = 2.0
else:
# 範囲外ならペナルティ
reward = -2.0
rewards.append(float(reward))
return rewards
def main():
# Wiki40b QAデータセットを読み込む
wiki_qa_dataset = load_wiki40b_qa_data()
# プロンプトとして質問、完了として回答のデータセットを作成
dataset = Dataset.from_dict(
{
"prompt": [f"以下の質問に対して、日本語で簡潔に回答してください: {item['question']}" for item in wiki_qa_dataset],
}
)
# モデルのロード
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
device_map="auto",
torch_dtype=torch.bfloat16,
attn_implementation="eager",
)
# トークナイザーの設定
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# コールバック
save_model_and_tokenizer_callback = SaveModelAndTokenizerCallback(
model, tokenizer, models_save_path
)
# トレーニング設定
training_args = GRPOConfig(
output_dir=models_save_path,
num_train_epochs=20,
learning_rate=1e-5,
per_device_train_batch_size=4,
logging_steps=1,
lr_scheduler_type="cosine",
save_strategy="epoch",
optim="adamw_bnb_8bit",
# GRPO固有の設定
num_generations=2,
temperature=0.7,
max_prompt_length=256,
max_completion_length=1024,
)
# wandbを初期化
wandb.init(project="llm-jp-3-1.8b-instruct3-grpo", name="grpo-0323")
trainer = GRPOTrainer(
model=model,
train_dataset=dataset,
args=training_args,
reward_funcs=[ethics_reward, length_reward], # 複数の報酬関数をリストとして渡す
callbacks=[save_model_and_tokenizer_callback],
)
trainer.train()
if __name__ == "__main__":
main()
- Downloads last month
- 6
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋
Ask for provider support
HF Inference deployability: The model has no library tag.