File size: 4,213 Bytes
f50dc54 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
#!/usr/bin/env python3
"""
VeRL model generation test.
Converts the data into a list of message dicts and checks whether the model
produces a sensible response.
"""
import os
import sys
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
print("=" * 80)
print("VeRL λͺ¨λΈ μμ± ν
μ€νΈ")
print("=" * 80)
# GPU μ€μ
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
# 1. λͺ¨λΈκ³Ό ν ν¬λμ΄μ λ‘λ
print("\nλͺ¨λΈ λ‘λ© μ€...")
model_name = "Qwen/Qwen2.5-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto"
)
print(f"λͺ¨λΈ λ‘λ μλ£: {model_name}")
# 2. Load the data.
# Can be overridden with the DATA_PATH environment variable; defaults to the
# round-1 induction data this test was written against.
data_path = os.environ.get(
    "DATA_PATH",
    "/home/ubuntu/RLVR/TestTime-RLVR-v2/tmp/batch_results/ttrlvr_azr_unified_20250822_151912/mbpp/Mbpp_2/round_1/azr_training_data/induction.parquet",
)
df = pd.read_parquet(data_path)
print(f"\n데이터 로드: {len(df)} 샘플")
# The sample-building step reads df.iloc[0]; stop here with a clear message
# rather than failing later with an IndexError.
if df.empty:
    sys.exit(f"데이터가 비어 있습니다: {data_path}")
# 3. Samples to test, as (label, chat messages) pairs.
test_samples = []

# Sample 1: the first row of the real data. String prompts are cut to their
# first PROMPT_CHAR_LIMIT characters to keep the request short.
PROMPT_CHAR_LIMIT = 1000
prompt_str = df.iloc[0]['prompt']
if isinstance(prompt_str, str):
    truncated_prompt = prompt_str[:PROMPT_CHAR_LIMIT] + "\n\nAssistant:"
    prompt_dict = [{"role": "user", "content": truncated_prompt}]
else:
    # Non-string prompts are assumed to be a sequence of message mappings
    # (parquet list columns may come back from pandas as a numpy array rather
    # than a list -- TODO confirm for this dataset). Convert to a plain list of
    # plain dicts, which is the format this script sets out to test.
    prompt_dict = [dict(message) for message in prompt_str]
test_samples.append(("실제 데이터 샘플 (truncated)", prompt_dict))

# Sample 2: a simple coding problem.
simple_prompt = [{
    "role": "user",
    "content": "Write a Python function to calculate the factorial of a number."
}]
test_samples.append(("간단한 코딩 문제", simple_prompt))

# Sample 3: an AZR-style prompt (task description followed by a signature).
azr_style_prompt = [{
    "role": "user",
    "content": """Write a function that takes a list of numbers and returns the sum of all even numbers.
def sum_even_numbers(numbers):"""
}]
test_samples.append(("AZR 스타일", azr_style_prompt))
# 4. Run generation for every sample.
print("\n" + "=" * 80)
print("생성 테스트 시작")
print("=" * 80)

# Keyword lists for the rough quality check at the end of each iteration;
# both are matched against the lower-cased response.
CODE_KEYWORDS = ('def ', 'return', 'function', 'python', '```')
# Strings treated as signs of a broken ("strange") generation.
ODD_KEYWORDS = ('stravinsky', 'department', 'openstring', '中文')

# Use EOS as the pad id when the tokenizer defines no pad token, so generate()
# always receives a concrete value.
pad_token_id = (
    tokenizer.pad_token_id
    if tokenizer.pad_token_id is not None
    else tokenizer.eos_token_id
)

for i, (name, prompt_dict) in enumerate(test_samples, 1):
    print(f"\n[테스트 {i}] {name}")
    print("-" * 60)

    # Render the messages with the model's chat template; add_generation_prompt
    # appends the assistant-turn header at the END of the string.
    prompt_with_template = tokenizer.apply_chat_template(
        prompt_dict,
        add_generation_prompt=True,
        tokenize=False,
    )
    print(f"템플릿 적용 후 시작: {prompt_with_template[:100]!r}...")

    # No truncation: the tokenizer truncates from the right by default, which
    # would drop the assistant header appended above. The template already
    # contains the special tokens, so don't add them a second time.
    # Inputs go to model.device because the model was placed with
    # device_map="auto".
    inputs = tokenizer(
        prompt_with_template,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)
    prompt_len = inputs['input_ids'].shape[1]
    print(f"입력 토큰 수: {prompt_len}")

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            temperature=0.7,
            do_sample=True,
            top_p=0.95,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (skip the echoed prompt).
    generated = outputs[0][prompt_len:]
    response = tokenizer.decode(generated, skip_special_tokens=True)
    print("\n생성된 응답:")
    print(">" * 40)
    print(response[:500])  # show only the first 500 characters
    print("<" * 40)

    # Rough heuristic check of the response; code keywords take priority.
    lowered = response.lower()
    if any(keyword in lowered for keyword in CODE_KEYWORDS):
        print("✅ 코드 관련 응답 생성됨")
    elif any(keyword in lowered for keyword in ODD_KEYWORDS):
        print("❌ 이상한 응답 생성됨")
    else:
        print("⚠ 응답 타입 불명확")
print("\n" + "=" * 80)
print("ν
μ€νΈ μλ£")
print("=" * 80)
# 5. κ²°λ‘
print("\nκ²°λ‘ :")
print("-" * 60)
print("μ ν
μ€νΈμμ μ μμ μΈ μ½λκ° μμ±λλ©΄:")
print(" β λμ
λ리 리μ€νΈ νμμ΄ μ¬λ°λ₯΄κ² μλν¨")
print(" β complete_pipeline.py μμ μ΄ μ λλ‘ μ μ©λμ§ μμ κ²μ΄ λ¬Έμ ")
print("\nμ΄μν μλ΅μ΄ κ³μ μμ±λλ©΄:")
print(" β λ€λ₯Έ κ·Όλ³Έμ μΈ λ¬Έμ κ° μμ μ μμ") |