| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | The main entry point to run the PPO algorithm |
| | """ |
| |
|
| | from typing import Literal, Optional, Union |
| |
|
| | import numpy as np |
| | import psutil |
| | import torch |
| | import torch.distributed as dist |
| | from accelerate import init_empty_weights |
| | from codetiming import Timer |
| | from torch.distributed.device_mesh import init_device_mesh |
| | from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy |
| | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| | from transformers import ( |
| | AutoConfig, |
| | AutoModelForCausalLM, |
| | AutoModelForTokenClassification, |
| | AutoModelForVision2Seq, |
| | GenerationConfig, |
| | PreTrainedModel, |
| | ) |
| | from transformers.modeling_utils import no_init_weights |
| |
|
| | from ..models.monkey_patch import apply_ulysses_patch |
| | from ..protocol import DataProto |
| | from ..single_controller.base import Worker |
| | from ..single_controller.base.decorator import Dispatch, register |
| | from ..utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager |
| | from ..utils.flops_counter import FlopsCounter |
| | from ..utils.fsdp_utils import ( |
| | get_fsdp_wrap_policy, |
| | get_init_fn, |
| | load_fsdp_model, |
| | load_fsdp_optimizer, |
| | offload_fsdp_model, |
| | offload_fsdp_optimizer, |
| | ) |
| | from ..utils.model_utils import print_gpu_memory_usage, print_model_size |
| | from ..utils.tokenizer import get_processor, get_tokenizer |
| | from ..utils.torch_dtypes import PrecisionType |
| | from ..utils.torch_functional import AnyPrecisionAdamW, get_constant_schedule_with_warmup |
| | from .actor import DataParallelPPOActor |
| | from .config import ActorConfig, CriticConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig, WorkerConfig |
| | from .critic import DataParallelPPOCritic |
| | from .rollout import vLLMRollout |
| | from .sharding_manager import FSDPVLLMShardingManager |
| | from .sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager |
| |
|
| |
|
class FSDPWorker(Worker):
    """A single-controller worker that hosts one or more PPO roles (actor, critic,
    rollout engine, reference policy) on top of PyTorch FSDP.

    Hybrid roles such as ``"actor_rollout_ref"`` enable several role flags at once
    so one process group can share a single set of model weights.
    """

    def __init__(
        self,
        config: WorkerConfig,
        role: Literal["actor", "critic", "rollout", "ref", "actor_rollout", "actor_rollout_ref"],
    ):
        super().__init__()
        self.config = config
        self.role = role

        # Lazily set up the collective backend; workers may be spawned before
        # any process group exists.
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")

        # Disable TF32 matmuls and reduced-precision bf16 reductions so that
        # log-prob/value computations are bit-stable across runs.
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False

        # Decompose the (possibly hybrid) role string into capability flags.
        self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
        self._is_critic = self.role == "critic"
        self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
        self._is_ref = self.role in ["ref", "actor_rollout_ref"]

        # Offload flags default to off; for hybrid roles the actor's settings win
        # because the elif chain below checks actor first.
        self._use_param_offload = False
        self._use_optimizer_offload = False
        if self._is_actor:
            self._use_param_offload = self.config.actor.offload.offload_params
            self._use_optimizer_offload = self.config.actor.offload.offload_optimizer
            self._init_config(self.config.actor, "actor")
        elif self._is_critic:
            self._use_param_offload = self.config.critic.offload.offload_params
            self._use_optimizer_offload = self.config.critic.offload.offload_optimizer
            self._init_config(self.config.critic, "critic")
        elif self._is_ref:
            # Reference policy never trains, so there is no optimizer to offload.
            self._use_param_offload = self.config.ref.offload.offload_params
            self._init_config(self.config.ref, "ref")
| |
|
| | def _init_config( |
| | self, config: Union[ActorConfig, CriticConfig, RefConfig], role: Literal["actor", "critic", "ref"] |
| | ): |
| | world_size = dist.get_world_size() |
| | fsdp_size = config.fsdp.fsdp_size |
| | if fsdp_size <= 0 or fsdp_size >= world_size: |
| | self.device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) |
| | else: |
| | self.device_mesh = init_device_mesh( |
| | "cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=("ddp", "fsdp") |
| | ) |
| |
|
| | if config.ulysses_sequence_parallel_size > 1: |
| | self.ulysses_device_mesh = init_device_mesh( |
| | "cuda", |
| | mesh_shape=( |
| | world_size // config.ulysses_sequence_parallel_size, |
| | config.ulysses_sequence_parallel_size, |
| | ), |
| | mesh_dim_names=("dp", "sp"), |
| | ) |
| | else: |
| | self.ulysses_device_mesh = None |
| |
|
| | self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) |
| |
|
| | if not hasattr(config, "global_batch_size"): |
| | return |
| |
|
| | if self.config.rollout.n > 1: |
| | config.global_batch_size *= self.config.rollout.n |
| | self.print_rank0(f"{role} will use global batch size {config.global_batch_size}.") |
| |
|
| | config.global_batch_size_per_device = ( |
| | config.global_batch_size // self.device_mesh.size() * config.ulysses_sequence_parallel_size |
| | ) |
| | if config.global_batch_size_per_device == 0: |
| | raise ValueError(f"{role} global batch size * ulysses size must be larger than num gpus.") |
| |
|
| | if config.global_batch_size_per_device % config.micro_batch_size_per_device_for_update != 0: |
| | raise ValueError(f"{role} global batch size per device must be divisible by the micro batch size.") |
| |
|
| | if ( |
| | config.fsdp.enable_cpu_offload |
| | and config.global_batch_size_per_device != config.micro_batch_size_per_device_for_update |
| | ): |
| | raise ValueError(f"{role} cannot use FSDP's CPU offload when gradient accumulation is enabled.") |
| |
|
    def _build_model_optimizer(
        self,
        model_config: ModelConfig,
        fsdp_config: FSDPConfig,
        optim_config: Optional[OptimConfig],
        padding_free: bool = False,
    ) -> None:
        """Load the Hugging Face model, wrap it in FSDP, and build the optimizer.

        Sets ``self.tokenizer``, ``self.processor``, ``self.model_config``,
        ``self.generation_config``, ``self.fsdp_module``, and (for actor/critic)
        ``self.optimizer`` / ``self.lr_scheduler``.

        Args:
            model_config: model/tokenizer paths and model-level toggles.
            fsdp_config: sharding, precision, and offload settings for FSDP.
            optim_config: optimizer settings; only consulted for actor/critic.
            padding_free: apply the Ulysses monkey patch for sequence parallelism.
        """
        self.tokenizer = get_tokenizer(
            model_config.tokenizer_path,
            trust_remote_code=model_config.trust_remote_code,
            use_fast=True,
        )
        self.processor = get_processor(
            model_config.tokenizer_path,
            trust_remote_code=model_config.trust_remote_code,
            use_fast=True,
        )
        # Force the special-token ids to agree with the tokenizer; user overrides
        # from `override_config` are applied last.
        self.model_config = AutoConfig.from_pretrained(
            model_config.model_path,
            trust_remote_code=model_config.trust_remote_code,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            **model_config.override_config,
        )

        # Some checkpoints ship without a generation_config.json; fall back to
        # deriving one from the model config.
        try:
            self.generation_config = GenerationConfig.from_pretrained(model_config.model_path)
        except Exception:
            self.generation_config = GenerationConfig.from_model_config(self.model_config)

        self.print_rank0(f"Model config: {self.model_config}")

        # Patch attention for padding-free Ulysses sequence parallelism before
        # the model class is instantiated.
        if padding_free:
            apply_ulysses_patch(self.model_config.model_type)
            self.print_rank0("Ulysses patch applied!")

        # Trainable roles (actor/critic) keep fp32 master weights; frozen roles
        # (ref/rollout-only) can live in bf16.
        if fsdp_config.torch_dtype is None:
            torch_dtype = torch.float32 if self._is_actor or self._is_critic else torch.bfloat16
        else:
            torch_dtype = PrecisionType.to_dtype(fsdp_config.torch_dtype)

        # Critic = scalar value head via token classification; otherwise pick the
        # VLM class when the config is registered for vision2seq.
        if self._is_critic:
            auto_class = AutoModelForTokenClassification
        elif type(self.model_config) in AutoModelForVision2Seq._model_mapping.keys():
            auto_class = AutoModelForVision2Seq
        else:
            auto_class = AutoModelForCausalLM
        
        # With rank0 init, only the fsdp-local rank 0 materializes real weights
        # (on CPU); every other rank builds an empty meta-device shell and later
        # receives weights via FSDP's sync_module_states broadcast.
        if (not fsdp_config.enable_rank0_init) or self.device_mesh.get_local_rank("fsdp") == 0:
            model = auto_class.from_pretrained(
                model_config.model_path,
                config=self.model_config,
                torch_dtype=torch_dtype,
                attn_implementation="flash_attention_2",
                device_map="cpu" if fsdp_config.enable_rank0_init else "cuda",
                low_cpu_mem_usage=True,
                trust_remote_code=model_config.trust_remote_code,
            )
        else:
            with no_init_weights(), init_empty_weights():
                model = auto_class.from_config(
                    self.model_config,
                    torch_dtype=torch_dtype,
                    attn_implementation="flash_attention_2",
                    trust_remote_code=model_config.trust_remote_code,
                )

        assert isinstance(model, PreTrainedModel)
        # from_config does not retie embeddings, so tie explicitly before casting.
        model.tie_weights()
        model = model.to(torch_dtype)
        if model_config.enable_gradient_checkpointing:
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

        # Frozen roles (ref / rollout-only) never need gradients.
        if not (self._is_actor or self._is_critic):
            model.requires_grad_(False)

        if model_config.freeze_vision_tower:
            if hasattr(model, "visual"):
                model.visual.requires_grad_(False)
                # Mixed trainable/frozen params require use_orig_params in FSDP.
                fsdp_config.use_orig_params = True
                self.print_rank0("Vision tower is set to not trainable.")
            else:
                self.print_rank0("No vision tower found.")

        dist.barrier()
        print_model_size(model)
        print_gpu_memory_usage("After huggingface model init")
        mixed_precision = MixedPrecision(
            param_dtype=PrecisionType.to_dtype(fsdp_config.mp_param_dtype),
            reduce_dtype=PrecisionType.to_dtype(fsdp_config.mp_reduce_dtype),
            buffer_dtype=PrecisionType.to_dtype(fsdp_config.mp_buffer_dtype),
        )
        auto_wrap_policy = get_fsdp_wrap_policy(model)
        self.print_rank0(f"FSDP wrap policy: {auto_wrap_policy}.")

        # 2-D (ddp, fsdp) mesh -> hybrid sharding variants; 1-D mesh -> plain
        # full-shard (ZeRO-3) or shard-grad-op (ZeRO-2).
        if self.device_mesh.ndim == 2:
            if fsdp_config.enable_full_shard:
                sharding_strategy = ShardingStrategy.HYBRID_SHARD
            else:
                sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
        else:
            if fsdp_config.enable_full_shard:
                sharding_strategy = ShardingStrategy.FULL_SHARD
            else:
                sharding_strategy = ShardingStrategy.SHARD_GRAD_OP

        if fsdp_config.enable_cpu_offload:
            cpu_offload = CPUOffload(offload_params=True)
        else:
            cpu_offload = None

        if fsdp_config.enable_rank0_init:
            sync_module_states = True
            # Non-zero ranks materialize their meta-device params on GPU before
            # rank 0's weights are broadcast.
            # NOTE(review): this checks the *global* rank while the load path above
            # checks the fsdp-local rank — confirm intended for 2-D meshes.
            param_init_fn = get_init_fn(model, device="cuda") if self.rank != 0 else None
        else:
            sync_module_states = False
            param_init_fn = None

        self.fsdp_module = FSDP(
            model,
            sharding_strategy=sharding_strategy,
            cpu_offload=cpu_offload,
            auto_wrap_policy=auto_wrap_policy,
            mixed_precision=mixed_precision,
            param_init_fn=param_init_fn,
            device_id=torch.cuda.current_device(),
            sync_module_states=sync_module_states,
            forward_prefetch=False,
            use_orig_params=fsdp_config.use_orig_params,
            device_mesh=self.device_mesh,
        )
        print_gpu_memory_usage("After FSDP module init")

        # Only trainable roles get an optimizer and LR schedule.
        if self._is_actor or self._is_critic:
            if optim_config.strategy == "adamw":
                self.optimizer = torch.optim.AdamW(
                    self.fsdp_module.parameters(),
                    lr=optim_config.lr,
                    betas=optim_config.betas,
                    weight_decay=optim_config.weight_decay,
                    fused=True,
                )
            elif optim_config.strategy == "adamw_bf16":
                self.optimizer = AnyPrecisionAdamW(
                    self.fsdp_module.parameters(),
                    lr=optim_config.lr,
                    betas=optim_config.betas,
                    weight_decay=optim_config.weight_decay,
                )
            else:
                raise NotImplementedError(f"Optimizer {optim_config.strategy} not supported.")

            num_warmup_steps = int(optim_config.lr_warmup_ratio * optim_config.training_steps)
            self.lr_scheduler = get_constant_schedule_with_warmup(
                optimizer=self.optimizer, num_warmup_steps=num_warmup_steps
            )
            print_gpu_memory_usage("After optimizer init")
        else:
            self.optimizer, self.lr_scheduler = None, None
| |
|
| | def _build_rollout(self) -> None: |
| | tp_size = self.config.rollout.tensor_parallel_size |
| | dp_size = self.world_size // tp_size |
| | assert self.world_size % tp_size == 0, ( |
| | f"rollout world size: {self.world_size} is not divisible by tp size: {tp_size}" |
| | ) |
| | rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=("dp", "tp")) |
| | self.rollout = vLLMRollout( |
| | model_path=self.config.actor.model.model_path, |
| | config=self.config.rollout, |
| | tokenizer=self.tokenizer, |
| | ) |
| | self.rollout_sharding_manager = FSDPVLLMShardingManager( |
| | module=self.fsdp_module, |
| | inference_engine=self.rollout.inference_engine, |
| | device_mesh=rollout_device_mesh, |
| | ) |
| | print_gpu_memory_usage("After vllm init") |
| |
|
| | @register(dispatch_mode=Dispatch.ONE_TO_ALL) |
| | def init_model(self): |
| | if self._is_critic: |
| | model_config = self.config.critic.model |
| | fsdp_config = self.config.critic.fsdp |
| | optim_config = self.config.critic.optim |
| | padding_free = self.config.critic.padding_free |
| | role = "critic" |
| | elif self._is_actor: |
| | model_config = self.config.actor.model |
| | fsdp_config = self.config.actor.fsdp |
| | optim_config = self.config.actor.optim |
| | padding_free = self.config.actor.padding_free |
| | role = "actor" |
| | elif self._is_ref: |
| | model_config = self.config.actor.model |
| | fsdp_config = self.config.ref.fsdp |
| | optim_config = None |
| | padding_free = self.config.ref.padding_free |
| | role = "ref" |
| | else: |
| | raise ValueError(f"Unknown role {role}.") |
| |
|
| | if self._is_actor or self._is_critic or self._is_ref: |
| | self._build_model_optimizer( |
| | model_config=model_config, |
| | fsdp_config=fsdp_config, |
| | optim_config=optim_config, |
| | padding_free=padding_free, |
| | ) |
| | if self._use_param_offload: |
| | offload_fsdp_model(self.fsdp_module) |
| | print_gpu_memory_usage(f"After offload {role} model during init") |
| |
|
| | if self._use_optimizer_offload: |
| | offload_fsdp_optimizer(optimizer=self.optimizer) |
| | print_gpu_memory_usage(f"After offload {role} optimizer during init") |
| |
|
| | if self._is_actor: |
| | self.actor = DataParallelPPOActor( |
| | config=self.config.actor, |
| | actor_module=self.fsdp_module, |
| | actor_optimizer=self.optimizer, |
| | ) |
| |
|
| | if self._is_critic: |
| | self.critic = DataParallelPPOCritic( |
| | config=self.config, |
| | critic_module=self.fsdp_module, |
| | critic_optimizer=self.optimizer, |
| | ) |
| |
|
| | if self._is_rollout: |
| | self._build_rollout() |
| |
|
| | if self._is_ref: |
| | self.ref_policy = DataParallelPPOActor( |
| | config=self.config.ref, |
| | actor_module=self.fsdp_module, |
| | ) |
| |
|
| | if self._is_actor or self._is_critic: |
| | self.flops_counter = FlopsCounter(self.model_config) |
| | self.checkpoint_manager = FSDPCheckpointManager( |
| | model=self.fsdp_module, |
| | optimizer=self.optimizer, |
| | lr_scheduler=self.lr_scheduler, |
| | processing_class=self.processor if self.processor is not None else self.tokenizer, |
| | ) |
| |
|
| | @register(dispatch_mode=Dispatch.ONE_TO_ALL) |
| | def save_checkpoint(self, path: str): |
| | assert self._is_actor or self._is_critic |
| | if self._use_param_offload: |
| | load_fsdp_model(self.fsdp_module) |
| |
|
| | self.checkpoint_manager.save_checkpoint(path) |
| | dist.barrier() |
| | if self._use_param_offload: |
| | offload_fsdp_model(self.fsdp_module) |
| |
|
| | @register(dispatch_mode=Dispatch.ONE_TO_ALL) |
| | def load_checkpoint(self, path: str): |
| | if self._use_param_offload: |
| | load_fsdp_model(self.fsdp_module) |
| |
|
| | self.checkpoint_manager.load_checkpoint(path) |
| | dist.barrier() |
| | if self._use_param_offload: |
| | offload_fsdp_model(self.fsdp_module) |
| |
|
| | if self._use_optimizer_offload: |
| | offload_fsdp_optimizer(self.optimizer) |
| |
|
| | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) |
| | def update_actor(self, data: DataProto): |
| | assert self._is_actor |
| | data = data.to(torch.cuda.current_device()) |
| |
|
| | if self._use_param_offload: |
| | load_fsdp_model(self.fsdp_module) |
| |
|
| | if self._use_optimizer_offload: |
| | load_fsdp_optimizer(optimizer=self.optimizer) |
| |
|
| | with self.ulysses_sharding_manager: |
| | data = self.ulysses_sharding_manager.preprocess_data(data=data) |
| | with Timer(name="update_policy", logger=None) as timer: |
| | metrics = self.actor.update_policy(data=data) |
| |
|
| | delta_time = timer.last |
| | global_num_tokens = data.meta_info["global_token_num"] |
| | estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) |
| | metrics["perf/mfu_actor"] = ( |
| | estimated_flops * self.config.actor.ppo_epochs / (promised_flops * self.world_size) |
| | ) |
| | metrics["perf/max_memory_allocated_gb"] = ( |
| | torch.cuda.max_memory_allocated() - self.rollout_sharding_manager.freed_bytes |
| | ) / (1024**3) |
| | metrics["perf/max_memory_reserved_gb"] = ( |
| | torch.cuda.max_memory_reserved() - self.rollout_sharding_manager.freed_bytes |
| | ) / (1024**3) |
| | metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) |
| |
|
| | self.lr_scheduler.step() |
| | lr = self.lr_scheduler.get_last_lr()[0] |
| | metrics["actor/lr"] = lr |
| |
|
| | |
| | output = DataProto( |
| | non_tensor_batch={ |
| | key: np.array([value] if np.isscalar(value) else value) for key, value in metrics.items() |
| | } |
| | ) |
| |
|
| | if self._use_param_offload: |
| | offload_fsdp_model(self.fsdp_module) |
| |
|
| | if self._use_optimizer_offload: |
| | offload_fsdp_optimizer(optimizer=self.optimizer) |
| |
|
| | output = output.to("cpu") |
| | return output |
| |
|
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def generate_sequences(self, prompts: DataProto):
        """Generate responses for a batch of prompts with the vLLM rollout engine.

        Returns a CPU-resident DataProto produced by the rollout engine and
        post-processed by the rollout sharding manager.
        """
        assert self._is_rollout

        # Weights must be on GPU so the sharding manager can sync them into vLLM.
        if self._use_param_offload:
            load_fsdp_model(self.fsdp_module)

        # Prefer the checkpoint's generation config for special-token ids; fall back
        # to the tokenizer when none was shipped.
        meta_info = {
            "eos_token_id": self.generation_config.eos_token_id
            if self.generation_config is not None
            else self.tokenizer.eos_token_id,
            "pad_token_id": self.generation_config.pad_token_id
            if self.generation_config is not None
            else self.tokenizer.pad_token_id,
        }
        prompts.meta_info.update(meta_info)
        with self.rollout_sharding_manager:
            # Offload happens inside the context: by now the manager has
            # (presumably) copied the weights into the inference engine, so the
            # FSDP copy can be evicted to free GPU memory for KV cache.
            if self._use_param_offload:
                offload_fsdp_model(self.fsdp_module)

            if self._use_optimizer_offload:
                offload_fsdp_optimizer(optimizer=self.optimizer)

            prompts = self.rollout_sharding_manager.preprocess_data(prompts)
            output = self.rollout.generate_sequences(prompts=prompts)
            output = self.rollout_sharding_manager.postprocess_data(output)

        output = output.to("cpu")
        return output
| |
|
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_log_probs(self, data: DataProto):
        """Recompute the behavior policy's token log-probs for a rollout batch.

        Returns a CPU-resident DataProto with tensor key ``"old_log_probs"``.
        """
        assert self._is_actor
        data = data.to(torch.cuda.current_device())
        if self._use_param_offload:
            load_fsdp_model(self.fsdp_module)

        # The sampling temperature must match rollout so logits are scaled identically.
        data.meta_info["temperature"] = self.config.rollout.temperature

        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data)
            output = self.actor.compute_log_prob(data=data)
            output = DataProto.from_dict(
                tensors={"old_log_probs": output}, meta_info={"temperature": self.config.rollout.temperature}
            )
            output = self.ulysses_sharding_manager.postprocess_data(output)

        # Force the root FSDP module to free its unsharded parameters after
        # forward-only use (the root does not reshard itself by default).
        # NOTE(review): relies on the private `_handle` attribute — may break
        # across torch versions; confirm against the pinned torch release.
        if self.world_size > 1:
            self.fsdp_module._handle.reshard(True)

        if self._use_param_offload:
            offload_fsdp_model(self.fsdp_module)

        output = output.to("cpu")
        return output
| |
|
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_ref_log_probs(self, data: DataProto):
        """Compute the frozen reference policy's token log-probs for KL penalties.

        Returns a CPU-resident DataProto with tensor key ``"ref_log_probs"``.
        """
        assert self._is_ref
        data = data.to(torch.cuda.current_device())
        if self._use_param_offload:
            load_fsdp_model(self.fsdp_module)

        # Match rollout's sampling temperature so logits are scaled identically.
        data.meta_info["temperature"] = self.config.rollout.temperature
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data)
            output = self.ref_policy.compute_log_prob(data=data)
            output = DataProto.from_dict(tensors={"ref_log_probs": output})
            output = self.ulysses_sharding_manager.postprocess_data(output)

        # Free the root module's unsharded parameters after forward-only use.
        # NOTE(review): private `_handle` API — verify against the pinned torch version.
        if self.world_size > 1:
            self.fsdp_module._handle.reshard(True)

        if self._use_param_offload:
            offload_fsdp_model(self.fsdp_module)

        output = output.to("cpu")
        return output
| |
|
| | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) |
| | def compute_values(self, data: DataProto): |
| | assert self._is_critic |
| | data = data.to(torch.cuda.current_device()) |
| | if self._use_param_offload: |
| | load_fsdp_model(self.fsdp_module) |
| |
|
| | with self.ulysses_sharding_manager: |
| | data = self.ulysses_sharding_manager.preprocess_data(data=data) |
| | values = self.critic.compute_values(data=data) |
| | output = DataProto.from_dict(tensors={"values": values}) |
| | output = self.ulysses_sharding_manager.postprocess_data(data=output) |
| |
|
| | if self._use_param_offload: |
| | offload_fsdp_model(self.fsdp_module) |
| |
|
| | output = output.to("cpu") |
| | return output |
| |
|
| | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) |
| | def update_critic(self, data: DataProto): |
| | data = data.to(torch.cuda.current_device()) |
| | if self._use_param_offload: |
| | load_fsdp_model(self.fsdp_module) |
| |
|
| | if self._use_optimizer_offload: |
| | load_fsdp_optimizer(optimizer=self.optimizer) |
| |
|
| | with self.ulysses_sharding_manager: |
| | data = self.ulysses_sharding_manager.preprocess_data(data=data) |
| | with Timer(name="update_critic", logger=None) as timer: |
| | metrics = self.critic.update_critic(data=data) |
| |
|
| | delta_time = timer.last |
| | global_num_tokens = data.meta_info["global_token_num"] |
| | estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) |
| | metrics["perf/mfu_critic"] = ( |
| | estimated_flops * self.config.actor.ppo_epochs / (promised_flops * self.world_size) |
| | ) |
| |
|
| | self.lr_scheduler.step() |
| | lr = self.lr_scheduler.get_last_lr()[0] |
| | metrics["critic/lr"] = lr |
| |
|
| | |
| | output = DataProto( |
| | non_tensor_batch={ |
| | metric: np.array([value] if np.isscalar(value) else value) for metric, value in metrics.items() |
| | } |
| | ) |
| |
|
| | if self._use_param_offload: |
| | offload_fsdp_model(self.fsdp_module) |
| |
|
| | if self._use_optimizer_offload: |
| | offload_fsdp_optimizer(optimizer=self.optimizer) |
| |
|
| | output = output.to("cpu") |
| | return output |
| |
|