arithmetic-grpo / verl /workers /rollout /arithmetic_sampling.py
LeTue09's picture
initial clean commit
1faccd4
import random
from functools import lru_cache
@lru_cache(maxsize=256)
def make_arithmetic_codes(group_size: int, seed: int) -> tuple[float, ...]:
    """Build the table of per-rollout sampling codes for one group.

    The unit interval [0, 1) is split into ``group_size`` equal-width cells
    and the midpoint of each cell is taken. Every midpoint is then rotated by
    one shared offset drawn from ``random.Random(seed)`` and wrapped back into
    [0, 1) with ``% 1.0``. The codes are therefore evenly spaced,
    ``1 / group_size`` apart (cyclically), and their common position depends
    on ``seed``.

    Results are memoised by ``lru_cache`` (up to 256 distinct argument
    patterns). Repeated calls with identical arguments are served from the
    cache. The table is a tuple, so the shared cached value cannot be mutated
    by callers.

    Args:
        group_size: Number of codes to produce; must be at least 1.
        seed: Seed for the private ``random.Random`` instance that picks the
            rotation offset. NOTE(review): CPython seeds from the absolute
            value of an int, so ``seed`` and ``-seed`` likely yield the same
            offset. Confirm this is acceptable if negative seeds can occur.

    Returns:
        A tuple of ``group_size`` floats in [0, 1), ordered by cell index.
        Because of the wrap-around, the tuple is not necessarily ascending.

    Raises:
        ValueError: If ``group_size`` is less than 1.
    """
    if group_size < 1:
        raise ValueError(f"group_size must be positive, got {group_size}")

    # One offset shared by every cell: a seed-dependent rotation of the grid.
    offset = random.Random(seed).random()

    codes = []
    for cell in range(group_size):
        midpoint = (cell + 0.5) / group_size
        # Rotate the cell midpoint, then wrap it back into [0, 1).
        codes.append((midpoint + offset) % 1.0)
    return tuple(codes)
def get_arithmetic_code(group_size: int, seed: int, rollout_n: int) -> float:
    """Look up the sampling code assigned to one rollout within its group.

    Args:
        group_size: Number of codes in the group's table; forwarded to
            ``make_arithmetic_codes``.
        seed: Seed that fixes the group's code table.
        rollout_n: Index of the rollout. It is reduced modulo the table
            length, so out-of-range values wrap around instead of raising.
            This includes negative values, via Python's floor modulo.

    Returns:
        The float code from the group's table for this rollout.

    Raises:
        ValueError: Propagated from ``make_arithmetic_codes`` when
            ``group_size`` is less than 1.
    """
    # Always use the same keyword call pattern: lru_cache may store
    # positional and keyword calls under separate cache entries.
    table = make_arithmetic_codes(group_size=group_size, seed=seed)
    slot = rollout_n % len(table)
    return table[slot]