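"""Correctness and benchmark tests for linear cross entropy implementations.

Compares a plain PyTorch reference against verl's original path, the
experimental torch-fused path (FusedLinearForPPO), and the fused kernel
(linear_cross_entropy), checking forward values, gradients, latency, and
peak memory.
"""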
import gc
import os

import torch

import verl.utils.torch_functional as verl_F
from verl.utils.experimental.torch_functional import FusedLinearForPPO
from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
from verl.utils.torch_functional import logprobs_from_logits
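
# Compile once at import time; dynamic=True lets the compiled graphs handle the
# varying token counts of the test cases without recompiling for each shape.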
compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
fused_linear_for_ppo = FusedLinearForPPO()
fused_linear_for_ppo.compile(dynamic=True)

# Parse as int: os.environ.get returns a string when the variable is set.
MAX_TEST_CASES = int(os.environ.get("MAX_TEST_CASES", 5))
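

# Each run_* helper below computes (logprobs, entropy) for one implementation.
# All take hidden [1, num_tokens, hidden_size], weight [vocab_size, hidden_size],
# and labels [1, num_tokens], as produced by generate_forward_inputs below.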
def run_torch_entropy(
    hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none"
) -> tuple[torch.Tensor, torch.Tensor]:
    """Plain PyTorch reference implementation in fp32."""
    hidden = hidden.squeeze(0).to(torch.float32)  # [num_tokens, hidden_size]
    weight = weight.transpose(0, 1).to(torch.float32)  # [hidden_size, vocab_size]
    logits = torch.matmul(hidden, weight)  # [num_tokens, vocab_size]
    logits /= temperature
    # entropy = logsumexp(logits) - sum(softmax(logits) * logits)
    pd = torch.nn.functional.softmax(logits, dim=-1)
    entropy_a = torch.logsumexp(logits, dim=-1)
    entropy_b = torch.sum(pd * logits, dim=-1)
    entropy = entropy_a - entropy_b
    # cross_entropy returns the negative log-likelihood; negate to recover log-probs
    logprobs = torch.nn.functional.cross_entropy(logits, labels.squeeze(0), reduction=reduction)
    logprobs = torch.neg(logprobs)
    return logprobs, entropy
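

# verl's original path: materializes the full logits tensor, then computes
# entropy with the compiled helper and log-probs with logprobs_from_logits.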
def run_verl_original_entropy(
    hidden: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    temperature: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    hidden = hidden.squeeze(0).to(torch.float32)  # [num_tokens, hidden_size]
    weight = weight.transpose(0, 1).to(torch.float32)  # [hidden_size, vocab_size]
    logits = torch.matmul(hidden, weight)  # [num_tokens, vocab_size]
    logits /= temperature

    entropy = compute_entropy_from_logits(logits)

    # labels must match logits' leading shape, so drop the batch dimension here too
    logprobs = logprobs_from_logits(logits=logits, labels=labels.squeeze(0), inplace_backward=False)
    return logprobs, entropy
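

# Torch-fused path: FusedLinearForPPO fuses the LM-head projection with the
# logprob/entropy computation, which is intended to reduce peak memory by not
# keeping the full [num_tokens, vocab_size] logits tensor around (see the
# peak-memory checks below).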
def run_verl_torch_fused_entropy(
    hidden: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    temperature: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    hidden = hidden.to(torch.float32)
    weight = weight.to(torch.float32)
    logprobs, entropy = fused_linear_for_ppo(
        hidden,
        weight,
        labels,
        temperature=temperature,
    )
    return logprobs.squeeze(0), entropy.squeeze(0)
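

# Test driver: one instance per (num_tokens, hidden_size, vocab_size) configuration.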
class TestLinearCrossEntropy:
    def __init__(self, test_case_idx: int, temperature: float = 1.5) -> None:
        self.test_case_idx = test_case_idx
        self.temperature = temperature

    def cleanup(self):
        """Free cached CUDA memory and reset allocator statistics between runs."""
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        gc.collect()
        torch.cuda.synchronize()

    def generate_hyper(self):
        # Shapes chosen to match representative LM-head configurations.
        self.dtype = torch.bfloat16
        if self.test_case_idx == 0:
            self.batch_size = 1
            self.num_tokens = 1937
            self.hidden_size = 3584
            self.vocab_size = 152064
        elif self.test_case_idx == 1:
            self.batch_size = 1
            self.num_tokens = 2169
            self.hidden_size = 896
            self.vocab_size = 151936
        elif self.test_case_idx == 2:
            self.batch_size = 1
            self.num_tokens = 1530
            self.hidden_size = 2048
            self.vocab_size = 32256
        elif self.test_case_idx == 3:
            self.batch_size = 1
            self.num_tokens = 1388
            self.hidden_size = 4096
            self.vocab_size = 102400
        elif self.test_case_idx == 4:
            self.batch_size = 1
            self.num_tokens = 8192
            self.hidden_size = 4096
            self.vocab_size = 102400
        else:
            raise ValueError(f"Invalid test case index: {self.test_case_idx}")
        assert MAX_TEST_CASES <= 5, "MAX_TEST_CASES should be less than or equal to 5."

    def generate_forward_inputs(self):
        hidden = (
            torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda")
            .uniform_(-0.5, 0.5)
            .requires_grad_()
        )
        weight = (
            torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda")
            .uniform_(-0.5, 0.5)
            .requires_grad_()
        )
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda")
        return hidden, weight, labels

    def generate_backward_inputs(self):
        # Random upstream gradients for entropy and logprobs
        g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5)
        g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1)
        return g_entropy, g_logprobs
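
    # Timing below uses CUDA events around each call; the first iteration is
    # treated as warm-up and dropped before averaging.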
    def verify_correctness(self, iterations=5):
        self.cleanup()
        self.generate_hyper()

        torch_forward_latency = list()
        torch_backward_latency = list()
        verl_forward_latency = list()
        verl_backward_latency = list()
        verl_fused_forward_latency = list()
        verl_fused_backward_latency = list()
        kernel_forward_latency = list()
        kernel_backward_latency = list()

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        for i in range(iterations):
            print(f"[INFO]: Iteration {i + 1} / {iterations}...", end="\r")
            hidden, weight, labels = self.generate_forward_inputs()

            # Forward: plain PyTorch reference
            start_event.record()
            (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels, self.temperature)
            end_event.record()
            torch.cuda.synchronize()
            torch_forward_latency.append(start_event.elapsed_time(end_event))

            # Forward: verl original
            start_event.record()
            (verl_logprobs, verl_entropy) = run_verl_original_entropy(hidden, weight, labels, self.temperature)
            end_event.record()
            torch.cuda.synchronize()
            verl_forward_latency.append(start_event.elapsed_time(end_event))

            # Forward: verl torch-fused
            start_event.record()
            (verl_fused_logprobs, verl_fused_entropy) = run_verl_torch_fused_entropy(
                hidden, weight, labels, self.temperature
            )
            end_event.record()
            torch.cuda.synchronize()
            verl_fused_forward_latency.append(start_event.elapsed_time(end_event))

            # Forward: fused CUDA kernel
            start_event.record()
            (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, self.temperature)
            end_event.record()
            torch.cuda.synchronize()
            kernel_forward_latency.append(start_event.elapsed_time(end_event))

            # Forward results: tight agreement expected among the non-kernel paths
            torch.testing.assert_close(torch_logprobs, verl_logprobs, atol=1e-4, rtol=1e-4)
            torch.testing.assert_close(torch_entropy, verl_entropy, atol=1e-4, rtol=1e-4)

            torch.testing.assert_close(torch_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4)
            torch.testing.assert_close(torch_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4)
            torch.testing.assert_close(verl_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4)
            torch.testing.assert_close(verl_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4)

            # The fused kernel is compared with looser tolerances
            torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4)
            torch.testing.assert_close(torch_entropy, kernel_entropy, atol=5e-3, rtol=5e-4)
            torch.testing.assert_close(verl_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4)
            torch.testing.assert_close(verl_entropy, kernel_entropy, atol=5e-3, rtol=5e-4)
            torch.testing.assert_close(verl_fused_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4)
            torch.testing.assert_close(verl_fused_entropy, kernel_entropy, atol=5e-3, rtol=5e-4)
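
            # Backward: push random upstream gradients through each implementation
            # and compare d_hidden / d_weight across all four.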
            g_entropy, g_logprobs = self.generate_backward_inputs()

            # Backward: plain PyTorch reference
            start_event.record()
            (d_torch_hidden, d_torch_weight) = torch.autograd.grad(
                (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            torch_backward_latency.append(start_event.elapsed_time(end_event))

            # Backward: verl original
            start_event.record()
            (d_verl_hidden, d_verl_weight) = torch.autograd.grad(
                (verl_entropy, verl_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            verl_backward_latency.append(start_event.elapsed_time(end_event))

            # Backward: verl torch-fused
            start_event.record()
            (d_verl_fused_hidden, d_verl_fused_weight) = torch.autograd.grad(
                (verl_fused_entropy, verl_fused_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            verl_fused_backward_latency.append(start_event.elapsed_time(end_event))

            # Backward: fused CUDA kernel
            start_event.record()
            (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad(
                (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            kernel_backward_latency.append(start_event.elapsed_time(end_event))

            # Gradient agreement among the non-kernel paths
            torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4)
            torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4)

            torch.testing.assert_close(d_torch_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4)
            torch.testing.assert_close(d_torch_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4)
            torch.testing.assert_close(d_verl_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4)
            torch.testing.assert_close(d_verl_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4)

            # Kernel gradients, compared with looser tolerances
            torch.testing.assert_close(d_torch_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(d_torch_weight, d_kernel_weight, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(d_verl_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(d_verl_weight, d_kernel_weight, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(d_verl_fused_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(d_verl_fused_weight, d_kernel_weight, atol=2e-2, rtol=4e-2)
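
        # Drop the first (warm-up) iteration before averaging latencies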
        torch_forward_latency = torch_forward_latency[1:]
        torch_backward_latency = torch_backward_latency[1:]
        verl_forward_latency = verl_forward_latency[1:]
        verl_backward_latency = verl_backward_latency[1:]
        verl_fused_forward_latency = verl_fused_forward_latency[1:]
        verl_fused_backward_latency = verl_fused_backward_latency[1:]
        kernel_forward_latency = kernel_forward_latency[1:]
        kernel_backward_latency = kernel_backward_latency[1:]

        print("\n[INFO]: Verified forward & backward correctness.")

        print(
            f"[INFO]: Forward pass: Torch implementation average time: "
            f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Backward pass: Torch implementation average time: "
            f"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Forward pass: VeRL implementation average time: "
            f"{sum(verl_forward_latency) / len(verl_forward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Backward pass: VeRL implementation average time: "
            f"{sum(verl_backward_latency) / len(verl_backward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Forward pass: VeRL Fused Entropy implementation average time: "
            f"{sum(verl_fused_forward_latency) / len(verl_fused_forward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Backward pass: VeRL Fused Entropy implementation average time: "
            f"{sum(verl_fused_backward_latency) / len(verl_fused_backward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Forward pass: Kernel implementation average time: "
            f"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Backward pass: Kernel implementation average time: "
            f"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms"
        )
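
    # Peak-memory check: reset allocator stats, run one forward and one backward,
    # and report torch.cuda.max_memory_allocated for each implementation.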
    def check_storage(self, method_name, run_forward):
        self.cleanup()
        self.generate_hyper()

        hidden, weight, labels = self.generate_forward_inputs()

        torch.cuda.reset_peak_memory_stats()
        (logprobs, entropy) = run_forward(hidden, weight, labels, self.temperature)
        torch.cuda.synchronize()
        forward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
        print(f"[INFO]: {method_name} Forward pass peak memory: {forward_max_memory:.2f} MB")

        g_entropy, g_logprobs = self.generate_backward_inputs()

        torch.cuda.reset_peak_memory_stats()
        (d_hidden, d_weight) = torch.autograd.grad(
            (entropy, logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
        )
        torch.cuda.synchronize()
        backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
        print(f"[INFO]: {method_name} Backward pass peak memory: {backward_max_memory:.2f} MB")

    def check_storage_all(self):
        self.check_storage("Torch", run_torch_entropy)
        self.check_storage("VeRL", run_verl_original_entropy)
        self.check_storage("VeRL Torch Fused", run_verl_torch_fused_entropy)
        self.check_storage("Kernel", linear_cross_entropy)


if __name__ == "__main__":
    for test_case_idx in range(MAX_TEST_CASES):
        print(f"[INFO] Running test case {test_case_idx}")
        test = TestLinearCrossEntropy(test_case_idx)

        test.verify_correctness()
        test.check_storage_all()