import os
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification
)
from datasets import Dataset, load_dataset
from transformers.trainer_callback import TrainerCallback
from trl import GRPOConfig, GRPOTrainer
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import wandb  # wandbをインポート

# Constants
MODEL_NAME = "llm-jp/llm-jp-3-1.8b-instruct3"
models_save_path = r"/home/tetora/デスクトップ/llm_sft/model"

MAX_TOKEN = 4096

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def count_tokens(text):
    """テキストのトークン数をカウントする"""
    return len(tokenizer.encode(text, add_special_tokens=False))

def load_wiki40b_qa_data():
    """Wiki40b QAデータセットを読み込み、5万件に制限する"""
    dataset = load_dataset("U23-lab/wiki40b_qa_ja", split="train")
    
    # データセットを制限し、ランダムにシャッフル
    dataset = dataset.shuffle(seed=42).select(range(5000))
    print(f"データセットを5千件に制限しました。現在のサイズ: {len(dataset)}")
    
    # 最大トークン数を確認
    max_token_count = 0
    for item in tqdm(dataset, desc="Checking token counts"):
        qa_text = f"質問: {item['question']}\n回答: {item['answer']}"
        token_count = count_tokens(qa_text)
        max_token_count = max(max_token_count, token_count)
    
    print(f"最大トークン数: {max_token_count}")
    return dataset

class SaveModelAndTokenizerCallback(TrainerCallback):
    def __init__(self, model, tokenizer, output_dir):
        self.model = model
        self.tokenizer = tokenizer
        self.output_dir = output_dir

    def on_epoch_end(self, args, state, control, **kwargs):
        checkpoint_dir = os.path.join(
            self.output_dir, f"checkpoint-{state.global_step}"
        )
        os.makedirs(checkpoint_dir, exist_ok=True)
        self.tokenizer.save_pretrained(checkpoint_dir)

# 事前に埋め込みモデルをロード(従来の報酬関数用)
embedding_model = SentenceTransformer('sbintuitions/modernbert-ja-310m')

# 倫理性に関する報酬関数(swdq/ethics_classifier を利用)
def ethics_reward(completions, prompts=None, **kwargs):
    """swdq/ethics_classifierに基づいて倫理性評価を行う報酬関数。
    "evil"なら報酬(+2.0)、"justice"ならペナルティ(-2.0)を与えます。"""
    # モデルパスとラベル辞書の設定
    model_path = "swdq/ethics_classifier"
    label_dict = {"evil": 0, "justice": 1}
    
    # モデルとトークナイザーのロード
    tokenizer_cls = AutoTokenizer.from_pretrained(model_path)
    model_cls = AutoModelForSequenceClassification.from_pretrained(model_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_cls.to(device)
    model_cls.eval()
    
    # テキストの処理
    inputs = tokenizer_cls(completions, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # 推論
    with torch.no_grad():
        outputs = model_cls(**inputs)
    
    # ロジットをソフトマックスに変換
    probs = outputs.logits.softmax(dim=-1)
    
    rewards = []
    for prob in probs:
        predicted_class = torch.argmax(prob).item()
        if predicted_class == label_dict["evil"]:
            reward = 2.0
        else:
            reward = -2.0
        rewards.append(reward)
    
    return rewards

# 文章の長さに関する報酬関数
def length_reward(completions, prompts=None, **kwargs):
    """回答の長さに基づく報酬関数 (100~500文字の範囲を促進、範囲外はペナルティ)"""
    rewards = []
    
    for completion in completions:
        completion_length = len(completion)
        
        # 100~500文字の範囲内なら最高報酬
        if 100 <= completion_length <= 500:
            reward = 2.0
        else:
            # 範囲外ならペナルティ
            reward = -2.0
        
        rewards.append(float(reward))
    
    return rewards

def main():
    # Wiki40b QAデータセットを読み込む
    wiki_qa_dataset = load_wiki40b_qa_data()
    
    # プロンプトとして質問、完了として回答のデータセットを作成
    dataset = Dataset.from_dict(
        {
            "prompt": [f"以下の質問に対して、日本語で簡潔に回答してください: {item['question']}" for item in wiki_qa_dataset],
        }
    )

    # モデルのロード
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )

    # トークナイザーの設定
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # コールバック
    save_model_and_tokenizer_callback = SaveModelAndTokenizerCallback(
        model, tokenizer, models_save_path
    )

    # トレーニング設定
    training_args = GRPOConfig(
        output_dir=models_save_path,
        num_train_epochs=20,
        learning_rate=1e-5,
        per_device_train_batch_size=4,
        logging_steps=1,
        lr_scheduler_type="cosine",
        save_strategy="epoch",
        optim="adamw_bnb_8bit",
        
        # GRPO固有の設定
        num_generations=2,
        temperature=0.7,
        max_prompt_length=256,
        max_completion_length=1024,
    )

    # wandbを初期化
    wandb.init(project="llm-jp-3-1.8b-instruct3-grpo", name="grpo-0323")

    trainer = GRPOTrainer(
        model=model,
        train_dataset=dataset,
        args=training_args,
        reward_funcs=[ethics_reward, length_reward],  # 複数の報酬関数をリストとして渡す
        callbacks=[save_model_and_tokenizer_callback],
    )
    
    trainer.train()

if __name__ == "__main__":
    main()

image/png

Downloads last month
6
Safetensors
Model size
1.87B params
Tensor type
BF16
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Dataset used to train swdq/llm-jp-3-1.8b-evil-test