| """Policy wrapper for making voice models RL-compatible.""" | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Tuple, Optional | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class PolicyValueHead(nn.Module):
    """
    Policy and value head for RL training on voice models.

    Adds a policy head (for action log probabilities) and a value head
    (for state-value estimation) on top of a voice model's hidden states.
    """

    def __init__(
        self,
        hidden_size: int,
        action_dim: int = 256,
        value_hidden_size: int = 128
    ):
        """
        Initialize policy and value heads.

        Args:
            hidden_size: Size of the base model's hidden states
            action_dim: Dimensionality of the action space
            value_hidden_size: Hidden size of the value network
        """
        super().__init__()

        # Policy head - outputs action logits
        self.policy_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, action_dim)
        )

        # Value head - outputs a state-value estimate
        self.value_head = nn.Sequential(
            nn.Linear(hidden_size, value_hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(value_hidden_size, 1)
        )

        logger.info(f"Initialized PolicyValueHead with hidden_size={hidden_size}, action_dim={action_dim}")
    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the policy and value heads.

        Args:
            hidden_states: Hidden states from the base model [batch, seq_len, hidden_size]

        Returns:
            Tuple of (action_logits, state_values)
        """
        # Pool hidden states (mean pooling over the sequence dimension)
        pooled = hidden_states.mean(dim=1)  # [batch, hidden_size]

        # Get action logits and state values
        action_logits = self.policy_head(pooled)  # [batch, action_dim]
        state_values = self.value_head(pooled)  # [batch, 1]

        return action_logits, state_values
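
# Usage sketch (shapes are illustrative, not part of the original module):
# with a 768-dim encoder such as wav2vec2-base, the head maps pooled hidden
# states to action logits and a scalar value per example.
#
#   head = PolicyValueHead(hidden_size=768, action_dim=256)
#   hidden = torch.randn(4, 50, 768)      # [batch=4, seq_len=50, hidden_size=768]
#   logits, values = head(hidden)         # logits: [4, 256], values: [4, 1]
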
class RLVoiceModel(nn.Module):
    """
    RL-compatible wrapper for voice models.

    Wraps a HuggingFace voice model and adds policy/value heads
    for reinforcement learning training.
    """

    def __init__(
        self,
        base_model: nn.Module,
        hidden_size: int,
        action_dim: int = 256,
        action_representation: str = "discrete"
    ):
        """
        Initialize the RL voice model wrapper.

        Args:
            base_model: Base voice model (e.g., wav2vec2)
            hidden_size: Hidden size of the base model
            action_dim: Dimensionality of the action space
            action_representation: "discrete" or "continuous"
        """
        super().__init__()

        self.base_model = base_model
        self.hidden_size = hidden_size
        self.action_dim = action_dim
        self.action_representation = action_representation

        # Add policy and value heads
        self.policy_value_head = PolicyValueHead(
            hidden_size=hidden_size,
            action_dim=action_dim
        )

        logger.info(f"Initialized RLVoiceModel with action_representation={action_representation}")
    def forward(
        self,
        input_features: torch.Tensor,
        return_hidden_states: bool = False,
        **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass for RL training.

        Args:
            input_features: Input audio features [batch, seq_len, features]
            return_hidden_states: Whether to return the base model's hidden states
            **kwargs: Additional arguments passed to the base model

        Returns:
            Tuple of (log_probs, values, hidden_states). For continuous action
            spaces, the first element holds the raw head outputs (action means)
            rather than log probabilities.
        """
        # Get base model outputs
        base_outputs = self.base_model(input_features, **kwargs)

        # Extract hidden states from the various output formats
        if hasattr(base_outputs, 'last_hidden_state'):
            hidden_states = base_outputs.last_hidden_state
        elif isinstance(base_outputs, torch.Tensor):
            hidden_states = base_outputs
        else:
            hidden_states = base_outputs[0]

        # Get policy and value outputs
        action_logits, state_values = self.policy_value_head(hidden_states)

        # Compute log probabilities
        if self.action_representation == "discrete":
            log_probs = F.log_softmax(action_logits, dim=-1)
        else:
            # For continuous actions, the head outputs are treated as action means
            log_probs = action_logits

        if return_hidden_states:
            return log_probs, state_values, hidden_states
        return log_probs, state_values, None
    def sample_action(
        self,
        input_features: torch.Tensor,
        deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Sample actions from the policy.

        Args:
            input_features: Input audio features
            deterministic: If True, take the most likely action

        Returns:
            Tuple of (actions, log_probs, values)
        """
        log_probs, values, _ = self.forward(input_features)

        if self.action_representation == "discrete":
            if deterministic:
                actions = log_probs.argmax(dim=-1)
            else:
                # Sample from the categorical distribution
                probs = torch.exp(log_probs)
                actions = torch.multinomial(probs, num_samples=1).squeeze(-1)
            # Get log probs of the selected actions
            action_log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
        else:
            # For continuous actions, the head output is the mean; add noise for exploration
            if deterministic:
                actions = log_probs
            else:
                actions = log_probs + torch.randn_like(log_probs) * 0.1
            # Simplified unnormalized Gaussian log-density around the mean
            # (same form as in evaluate_actions, so importance ratios stay consistent)
            action_log_probs = -0.5 * ((actions - log_probs) ** 2).sum(dim=-1)

        return actions, action_log_probs, values
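
    # Rollout sketch (illustrative; the environment, reward, and buffer are
    # assumed to exist elsewhere and are not part of this module):
    #
    #   with torch.no_grad():
    #       actions, old_log_probs, values = model.sample_action(features)
    #   buffer.add(features, actions, old_log_probs, values, reward)
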
    def evaluate_actions(
        self,
        input_features: torch.Tensor,
        actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Evaluate given actions (for PPO training).

        Args:
            input_features: Input audio features
            actions: Actions to evaluate

        Returns:
            Tuple of (log_probs, values, entropy)
        """
        log_probs, values, _ = self.forward(input_features)

        if self.action_representation == "discrete":
            # Get log probs of the given actions
            action_log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
            # Compute entropy of the categorical policy
            probs = torch.exp(log_probs)
            entropy = -(probs * log_probs).sum(dim=-1).mean()
        else:
            # Simplified unnormalized Gaussian log-density around the mean
            action_log_probs = -0.5 * ((actions - log_probs) ** 2).sum(dim=-1)
            # Entropy of a unit-variance diagonal Gaussian (constant in the mean)
            entropy = torch.tensor(
                0.5 * log_probs.shape[-1] * (1.0 + math.log(2.0 * math.pi)),
                device=log_probs.device
            )

        return action_log_probs, values.squeeze(-1), entropy
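
    # Sketch of how evaluate_actions feeds a PPO-style clipped objective
    # (assumes old_log_probs, advantages, and returns were collected during
    # rollout; names below are illustrative, not part of this module):
    #
    #   new_log_probs, values, entropy = model.evaluate_actions(features, actions)
    #   ratio = torch.exp(new_log_probs - old_log_probs)
    #   clipped = torch.clamp(ratio, 1 - 0.2, 1 + 0.2)
    #   policy_loss = -torch.min(ratio * advantages, clipped * advantages).mean()
    #   value_loss = F.mse_loss(values, returns)
    #   loss = policy_loss + 0.5 * value_loss - 0.01 * entropy
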
    def get_base_model(self) -> nn.Module:
        """Get the underlying base model."""
        return self.base_model

    def freeze_base_model(self) -> None:
        """Freeze base model parameters (only train policy/value heads)."""
        for param in self.base_model.parameters():
            param.requires_grad = False
        logger.info("Froze base model parameters")

    def unfreeze_base_model(self) -> None:
        """Unfreeze base model parameters."""
        for param in self.base_model.parameters():
            param.requires_grad = True
        logger.info("Unfroze base model parameters")
class SequentialVoicePolicy(nn.Module):
    """
    Sequential policy for frame-by-frame voice generation.

    For autoregressive voice generation where each frame is an action.
    """

    def __init__(
        self,
        base_model: nn.Module,
        hidden_size: int,
        frame_size: int = 80,  # e.g., an 80-dim mel spectrogram frame
        max_seq_len: int = 1000
    ):
        """
        Initialize the sequential voice policy.

        Args:
            base_model: Base model for processing the conditioning context
            hidden_size: Hidden size
            frame_size: Size of each output frame
            max_seq_len: Maximum sequence length
        """
        super().__init__()

        self.base_model = base_model
        self.hidden_size = hidden_size
        self.frame_size = frame_size
        self.max_seq_len = max_seq_len

        # Frame generation network
        self.frame_generator = nn.LSTM(
            input_size=hidden_size + frame_size,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True
        )

        # Output projection
        self.output_projection = nn.Linear(hidden_size, frame_size)

        # Value network
        self.value_net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, 1)
        )

        logger.info(f"Initialized SequentialVoicePolicy with frame_size={frame_size}")
    def forward(
        self,
        input_features: torch.Tensor,
        previous_frames: Optional[torch.Tensor] = None,
        num_frames: int = 10
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generate a sequence of frames.

        Args:
            input_features: Input conditioning features
            previous_frames: Previously generated frames (for autoregression)
            num_frames: Number of frames to generate

        Returns:
            Tuple of (generated_frames, log_probs, values)
        """
        batch_size = input_features.shape[0]

        # Get conditioning context from the base model
        base_outputs = self.base_model(input_features)
        if hasattr(base_outputs, 'last_hidden_state'):
            context = base_outputs.last_hidden_state.mean(dim=1)  # [batch, hidden]
        elif isinstance(base_outputs, torch.Tensor):
            context = base_outputs.mean(dim=1) if base_outputs.dim() > 2 else base_outputs
        else:
            context = base_outputs[0].mean(dim=1)

        # Initialize the first input frame
        if previous_frames is None:
            current_frame = torch.zeros(batch_size, self.frame_size, device=input_features.device)
        else:
            current_frame = previous_frames[:, -1]

        hidden = None
        generated_frames = []
        log_probs = []

        # Generate frames autoregressively
        for t in range(num_frames):
            # Combine context and previous frame
            lstm_input = torch.cat([context, current_frame], dim=-1).unsqueeze(1)

            # LSTM step
            lstm_out, hidden = self.frame_generator(lstm_input, hidden)

            # Project to frame
            frame_logits = self.output_projection(lstm_out.squeeze(1))

            # Treat the frame as a continuous output, bounded to [-1, 1]
            current_frame = torch.tanh(frame_logits)

            # Simplified log prob: unnormalized unit Gaussian over the pre-tanh output
            frame_log_prob = -0.5 * (frame_logits ** 2).sum(dim=-1)

            generated_frames.append(current_frame)
            log_probs.append(frame_log_prob)

        # Stack results
        generated_frames = torch.stack(generated_frames, dim=1)  # [batch, num_frames, frame_size]
        log_probs = torch.stack(log_probs, dim=1)  # [batch, num_frames]

        # Compute values from the conditioning context
        values = self.value_net(context)  # [batch, 1]

        return generated_frames, log_probs, values
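

# Minimal smoke test (a sketch, not part of the original module): uses a tiny
# stand-in encoder instead of a real voice model so the output shapes can be
# checked without downloading a checkpoint.
if __name__ == "__main__":
    class _ToyEncoder(nn.Module):
        """Stand-in encoder: maps [batch, seq, feat] to [batch, seq, hidden]."""

        def __init__(self, feat_dim: int = 80, hidden_size: int = 64):
            super().__init__()
            self.proj = nn.Linear(feat_dim, hidden_size)

        def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
            return self.proj(x)

    encoder = _ToyEncoder(feat_dim=80, hidden_size=64)
    features = torch.randn(2, 20, 80)  # [batch, seq_len, feat_dim]

    # Discrete-action wrapper
    rl_model = RLVoiceModel(encoder, hidden_size=64, action_dim=16)
    actions, logp, values = rl_model.sample_action(features)
    print("actions:", actions.shape, "log_probs:", logp.shape, "values:", values.shape)

    logp_new, values_new, entropy = rl_model.evaluate_actions(features, actions)
    print("evaluate:", logp_new.shape, values_new.shape, float(entropy))

    # Frame-by-frame policy
    seq_policy = SequentialVoicePolicy(encoder, hidden_size=64, frame_size=80)
    frames, frame_logp, frame_values = seq_policy(features, num_frames=5)
    print("frames:", frames.shape, "frame log_probs:", frame_logp.shape, "values:", frame_values.shape)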