| """ |
| Injects the style vector into the model via soft prompt conditioning. |
| The style vector is projected to the model's hidden dimension and |
| prepended to the input token embeddings as virtual tokens. |
| |
This technique is a form of soft prompting, referred to here as "style
prefix injection" (closely related to prefix tuning). It biases the
model's attention toward the desired output style without modifying the
base model weights.

For Flan-T5: injected into the encoder input embeddings.
For BART: injected into the encoder input embeddings.
For Llama: prepended to the full input context.

A minimal usage sketch appears at the bottom of this module.
"""

import torch
import torch.nn as nn


class StyleConditioner(nn.Module):
    """
    Projects a style vector of dimension style_dim to n_prefix_tokens
    virtual tokens in the model's embedding space.
    """

    def __init__(
        self,
        style_dim: int = 512,
        model_hidden_dim: int = 512,
        n_prefix_tokens: int = 10,
    ):
        super().__init__()
        self.style_dim = style_dim
        self.model_hidden_dim = model_hidden_dim
        self.n_prefix_tokens = n_prefix_tokens

        # Project the style vector to all prefix tokens in a single linear
        # map; the flat output is reshaped per token in forward(). Tanh
        # keeps the virtual embeddings in a bounded range.
        total_output_dim = n_prefix_tokens * model_hidden_dim
        self.projection = nn.Sequential(
            nn.Linear(style_dim, total_output_dim),
            nn.Tanh(),
        )

    def forward(self, style_vector: torch.Tensor) -> torch.Tensor:
        """
        Args:
            style_vector: [batch_size, style_dim]
        Returns:
            prefix_embeddings: [batch_size, n_prefix_tokens, model_hidden_dim]
        """
        projected = self.projection(style_vector)

        # Reshape the flat projection into one embedding per virtual token.
        batch_size = style_vector.size(0)
        prefix_embeddings = projected.view(
            batch_size, self.n_prefix_tokens, self.model_hidden_dim
        )
        return prefix_embeddings


def prepend_style_prefix(
    input_embeddings: torch.Tensor,
    style_prefix: torch.Tensor,
) -> torch.Tensor:
    """
    Concatenates style prefix to input embeddings along sequence dimension.

    Args:
        input_embeddings: [batch, seq_len, hidden_dim]
        style_prefix: [batch, n_prefix, hidden_dim]
    Returns:
        [batch, n_prefix + seq_len, hidden_dim]
    """
    return torch.cat([style_prefix, input_embeddings], dim=1)
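

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch with synthetic tensors. The batch size,
# sequence length, and 768-dim hidden size below are illustrative values, not
# fixed by this module. In practice the conditioned embeddings would be passed
# to the model (e.g. via an inputs_embeds argument), and the attention mask
# must be extended by n_prefix_tokens extra ones so the virtual tokens are
# attended to.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, seq_len, hidden_dim = 2, 16, 768

    conditioner = StyleConditioner(
        style_dim=512,
        model_hidden_dim=hidden_dim,
        n_prefix_tokens=10,
    )

    style_vector = torch.randn(batch_size, 512)
    input_embeddings = torch.randn(batch_size, seq_len, hidden_dim)
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    style_prefix = conditioner(style_vector)
    conditioned = prepend_style_prefix(input_embeddings, style_prefix)

    # Extend the attention mask on the left to cover the prepended prefix.
    prefix_mask = torch.ones(
        batch_size, conditioner.n_prefix_tokens, dtype=attention_mask.dtype
    )
    extended_mask = torch.cat([prefix_mask, attention_mask], dim=1)

    print(style_prefix.shape)   # torch.Size([2, 10, 768])
    print(conditioned.shape)    # torch.Size([2, 26, 768])
    print(extended_mask.shape)  # torch.Size([2, 26])

    # With a Hugging Face seq2seq model (e.g. Flan-T5 or BART), a hypothetical
    # integration could look like:
    #   embeds = model.get_input_embeddings()(input_ids)
    #   out = model(inputs_embeds=prepend_style_prefix(embeds, style_prefix),
    #               attention_mask=extended_mask, labels=labels)
    # (sketch only; argument handling depends on the model class used.)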