"""Policy wrapper for making voice models RL-compatible."""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
import logging
logger = logging.getLogger(__name__)
class PolicyValueHead(nn.Module):
"""
Policy and value head for RL training on voice models.
Adds a policy head (for action log probabilities) and value head
(for state value estimation) on top of a voice model's hidden states.
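
    Example (shape sketch, assuming 768-dim hidden states):

        >>> head = PolicyValueHead(hidden_size=768, action_dim=256)
        >>> hidden = torch.randn(2, 50, 768)
        >>> logits, values = head(hidden)
        >>> logits.shape, values.shape
        (torch.Size([2, 256]), torch.Size([2, 1]))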
"""
def __init__(
self,
hidden_size: int,
action_dim: int = 256,
value_hidden_size: int = 128
):
"""
Initialize policy and value heads.
Args:
hidden_size: Size of the base model's hidden states
action_dim: Dimensionality of the action space
value_hidden_size: Hidden size for value network
"""
super().__init__()
# Policy head - outputs action logits
self.policy_head = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 2),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_size // 2, action_dim)
)
# Value head - outputs state value estimate
self.value_head = nn.Sequential(
nn.Linear(hidden_size, value_hidden_size),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(value_hidden_size, 1)
)
logger.info(f"Initialized PolicyValueHead with hidden_size={hidden_size}, action_dim={action_dim}")
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass through policy and value heads.
Args:
hidden_states: Hidden states from base model [batch, seq_len, hidden_size]
Returns:
Tuple of (action_logits, state_values)
"""
# Pool hidden states (mean pooling over sequence)
pooled = hidden_states.mean(dim=1) # [batch, hidden_size]
# Get action logits and values
action_logits = self.policy_head(pooled) # [batch, action_dim]
state_values = self.value_head(pooled) # [batch, 1]
return action_logits, state_values
class RLVoiceModel(nn.Module):
"""
RL-compatible wrapper for voice models.
Wraps a HuggingFace voice model and adds policy/value heads
for reinforcement learning training.
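
    Example (usage sketch; the wav2vec2 checkpoint and the ``waveform`` input
    below are illustrative, not prescriptive):

        >>> from transformers import Wav2Vec2Model                           # doctest: +SKIP
        >>> base = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")   # doctest: +SKIP
        >>> model = RLVoiceModel(base, hidden_size=768, action_dim=256)      # doctest: +SKIP
        >>> actions, log_probs, values = model.sample_action(waveform)       # doctest: +SKIP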
"""
def __init__(
self,
base_model: nn.Module,
hidden_size: int,
action_dim: int = 256,
action_representation: str = "discrete"
):
"""
Initialize RL voice model wrapper.
Args:
base_model: Base voice model (e.g., wav2vec2)
hidden_size: Hidden size of base model
action_dim: Dimensionality of action space
action_representation: "discrete" or "continuous"
"""
super().__init__()
self.base_model = base_model
self.hidden_size = hidden_size
self.action_dim = action_dim
self.action_representation = action_representation
# Add policy and value heads
self.policy_value_head = PolicyValueHead(
hidden_size=hidden_size,
action_dim=action_dim
)
logger.info(f"Initialized RLVoiceModel with action_representation={action_representation}")
def forward(
self,
input_features: torch.Tensor,
return_hidden_states: bool = False,
**kwargs
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Forward pass for RL training.
Args:
input_features: Input audio features [batch, seq_len, features]
return_hidden_states: Whether to return base model hidden states
**kwargs: Additional arguments for base model
Returns:
Tuple of (log_probs, values, hidden_states)
"""
# Get base model outputs
base_outputs = self.base_model(input_features, **kwargs)
# Extract hidden states
if hasattr(base_outputs, 'last_hidden_state'):
hidden_states = base_outputs.last_hidden_state
elif isinstance(base_outputs, torch.Tensor):
hidden_states = base_outputs
else:
hidden_states = base_outputs[0]
# Get policy and value outputs
action_logits, state_values = self.policy_value_head(hidden_states)
# Compute log probabilities
if self.action_representation == "discrete":
log_probs = F.log_softmax(action_logits, dim=-1)
        else:
            # For continuous actions, the head output is used as the action
            # mean; it is returned in the "log_probs" slot for API symmetry.
            log_probs = action_logits
if return_hidden_states:
return log_probs, state_values, hidden_states
else:
return log_probs, state_values, None
def sample_action(
self,
input_features: torch.Tensor,
deterministic: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sample actions from the policy.
Args:
input_features: Input audio features
deterministic: If True, take most likely action
Returns:
Tuple of (actions, log_probs, values)
"""
log_probs, values, _ = self.forward(input_features)
if self.action_representation == "discrete":
if deterministic:
actions = log_probs.argmax(dim=-1)
else:
# Sample from categorical distribution
probs = torch.exp(log_probs)
actions = torch.multinomial(probs, num_samples=1).squeeze(-1)
# Get log prob of selected actions
action_log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
        else:
            # For continuous actions, treat the head output as the mean of a
            # Gaussian policy with a fixed exploration std of 0.1.
            action_std = 0.1
            if deterministic:
                actions = log_probs
            else:
                actions = log_probs + torch.randn_like(log_probs) * action_std
            # Log probability under that Gaussian (normalizing constants are
            # omitted; they cancel in PPO-style probability ratios).
            action_log_probs = -0.5 * (((actions - log_probs) / action_std) ** 2).sum(dim=-1)
return actions, action_log_probs, values
def evaluate_actions(
self,
input_features: torch.Tensor,
actions: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Evaluate actions (for PPO training).
Args:
input_features: Input audio features
actions: Actions to evaluate
Returns:
Tuple of (log_probs, values, entropy)
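
        Example (PPO-style usage sketch; ``old_log_probs`` and ``advantages``
        are assumed to come from a rollout buffer, and the 0.8/1.2 clip range
        is illustrative):

            >>> new_log_probs, values, entropy = model.evaluate_actions(features, actions)  # doctest: +SKIP
            >>> ratio = torch.exp(new_log_probs - old_log_probs)                            # doctest: +SKIP
            >>> policy_loss = -torch.min(ratio * advantages,
            ...                          ratio.clamp(0.8, 1.2) * advantages).mean()         # doctest: +SKIP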
"""
log_probs, values, _ = self.forward(input_features)
if self.action_representation == "discrete":
# Get log probs of given actions
action_log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
# Compute entropy
probs = torch.exp(log_probs)
entropy = -(probs * log_probs).sum(dim=-1).mean()
        else:
            # For continuous actions: Gaussian policy with the same fixed std
            # used in sample_action, constants omitted to match its log-prob.
            action_std = 0.1
            action_log_probs = -0.5 * (((actions - log_probs) / action_std) ** 2).sum(dim=-1)
            # Entropy of a diagonal Gaussian: 0.5 * D * (1 + log(2 * pi * std^2))
            entropy = torch.tensor(
                0.5 * log_probs.shape[-1] * (1.0 + math.log(2.0 * math.pi * action_std ** 2)),
                device=log_probs.device
            )
return action_log_probs, values.squeeze(-1), entropy
def get_base_model(self) -> nn.Module:
"""Get the underlying base model."""
return self.base_model
def freeze_base_model(self) -> None:
"""Freeze base model parameters (only train policy/value heads)."""
for param in self.base_model.parameters():
param.requires_grad = False
logger.info("Froze base model parameters")
def unfreeze_base_model(self) -> None:
"""Unfreeze base model parameters."""
for param in self.base_model.parameters():
param.requires_grad = True
logger.info("Unfroze base model parameters")
class SequentialVoicePolicy(nn.Module):
"""
Sequential policy for frame-by-frame voice generation.
For autoregressive voice generation where each frame is an action.
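
    Example (sketch; any encoder that maps inputs to [batch, seq, hidden]
    features is assumed to work as ``base_model``):

        >>> policy = SequentialVoicePolicy(encoder, hidden_size=256, frame_size=80)  # doctest: +SKIP
        >>> frames, log_probs, values = policy(features, num_frames=20)              # doctest: +SKIP
        >>> # frames: [batch, 20, 80], log_probs: [batch, 20], values: [batch, 1]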
"""
def __init__(
self,
base_model: nn.Module,
hidden_size: int,
frame_size: int = 80, # e.g., 80-dim mel spectrogram
max_seq_len: int = 1000
):
"""
Initialize sequential voice policy.
Args:
base_model: Base model for processing context
hidden_size: Hidden size
frame_size: Size of each output frame
max_seq_len: Maximum sequence length
"""
super().__init__()
self.base_model = base_model
self.hidden_size = hidden_size
self.frame_size = frame_size
self.max_seq_len = max_seq_len
# Frame generation network
self.frame_generator = nn.LSTM(
input_size=hidden_size + frame_size,
hidden_size=hidden_size,
num_layers=2,
batch_first=True
)
# Output projection
self.output_projection = nn.Linear(hidden_size, frame_size)
# Value network
self.value_net = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 2),
nn.ReLU(),
nn.Linear(hidden_size // 2, 1)
)
logger.info(f"Initialized SequentialVoicePolicy with frame_size={frame_size}")
def forward(
self,
input_features: torch.Tensor,
previous_frames: Optional[torch.Tensor] = None,
num_frames: int = 10
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Generate sequence of frames.
Args:
input_features: Input conditioning features
previous_frames: Previous generated frames (for autoregression)
num_frames: Number of frames to generate
Returns:
Tuple of (generated_frames, log_probs, values)
"""
batch_size = input_features.shape[0]
# Get context from base model
base_outputs = self.base_model(input_features)
        if hasattr(base_outputs, 'last_hidden_state'):
            context = base_outputs.last_hidden_state.mean(dim=1)  # [batch, hidden]
        else:
            # Handle plain tensors or tuple outputs from the base model
            feats = base_outputs if isinstance(base_outputs, torch.Tensor) else base_outputs[0]
            context = feats.mean(dim=1) if feats.dim() > 2 else feats
# Initialize
if previous_frames is None:
current_frame = torch.zeros(batch_size, self.frame_size, device=input_features.device)
else:
current_frame = previous_frames[:, -1]
hidden = None
generated_frames = []
log_probs = []
# Generate frames autoregressively
for t in range(num_frames):
# Combine context and previous frame
lstm_input = torch.cat([context, current_frame], dim=-1).unsqueeze(1)
# LSTM step
lstm_out, hidden = self.frame_generator(lstm_input, hidden)
# Project to frame
frame_logits = self.output_projection(lstm_out.squeeze(1))
# Sample frame (treat as continuous output)
current_frame = torch.tanh(frame_logits) # Bound to [-1, 1]
            # Log prob placeholder (simplified): squared-norm penalty on the
            # pre-tanh output, not a true density over the generated frame
            frame_log_prob = -0.5 * (frame_logits ** 2).sum(dim=-1)
generated_frames.append(current_frame)
log_probs.append(frame_log_prob)
# Stack results
generated_frames = torch.stack(generated_frames, dim=1) # [batch, num_frames, frame_size]
log_probs = torch.stack(log_probs, dim=1) # [batch, num_frames]
# Compute values
values = self.value_net(context) # [batch, 1]
return generated_frames, log_probs, values
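

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the library API): a tiny dummy
    # encoder stands in for a real voice model so the RL heads can be
    # exercised end to end; all sizes below are illustrative.
    class _DummyEncoder(nn.Module):
        def __init__(self, input_dim: int = 40, hidden_size: int = 64):
            super().__init__()
            self.proj = nn.Linear(input_dim, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Plain hidden states of shape [batch, seq_len, hidden_size]
            return torch.relu(self.proj(x))

    model = RLVoiceModel(_DummyEncoder(), hidden_size=64, action_dim=16)
    model.freeze_base_model()  # train only the policy/value heads

    dummy_audio = torch.randn(2, 100, 40)  # [batch, seq_len, features]
    actions, action_log_probs, values = model.sample_action(dummy_audio)
    print(actions.shape, action_log_probs.shape, values.shape)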