Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| try: | |
| import torch_npu | |
| from torch_npu.contrib import transfer_to_npu | |
| DEVICE_TYPE = "npu" | |
| except ModuleNotFoundError: | |
| DEVICE_TYPE = "cuda" | |
| from .text_encoder import T5TextEncoder | |
class SketchT5TextEncoder(T5TextEncoder):
    """T5 text encoder extended with projections for f0 and energy sketches.

    On top of the parent text encoder, this class owns two linear layers that
    map f0 and energy features to a common ``latent_dim``, plus a layer norm
    that is applied to f0 before it is projected.
    """

    def __init__(
        self,
        f0_dim: int,
        energy_dim: int,
        latent_dim: int,
        embed_dim: int,
        model_name: str = "google/flan-t5-large",
    ):
        """Build the parent encoder and the sketch projection layers.

        Args:
            f0_dim: Size of the last axis of the f0 input; used both as the
                layer-norm shape and as the f0 projection input width.
            energy_dim: Size of the last axis of the energy input.
            latent_dim: Output width shared by both projections.
            embed_dim: Forwarded unchanged to the parent constructor.
            model_name: Forwarded unchanged to the parent constructor.
        """
        super().__init__(embed_dim=embed_dim, model_name=model_name)
        # Creation order is kept stable: it determines parameter
        # registration order and seeded initialisation.
        self.f0_proj = nn.Linear(in_features=f0_dim, out_features=latent_dim)
        self.f0_norm = nn.LayerNorm(normalized_shape=f0_dim)
        self.energy_proj = nn.Linear(
            in_features=energy_dim, out_features=latent_dim
        )

    def encode(
        self,
        text: list[str],
    ):
        """Run the parent ``encode`` without gradients and without autocast.

        Args:
            text: Batch of input strings.

        Returns:
            Whatever the parent class's ``encode`` returns for ``text``.
        """
        with torch.no_grad():
            # Autocast is explicitly disabled for the text-encoding pass.
            with torch.amp.autocast(device_type=DEVICE_TYPE, enabled=False):
                return super().encode(text)

    def encode_sketch(
        self,
        f0,
        energy,
    ):
        """Project f0 and energy to ``latent_dim`` and pair them on a new axis.

        Args:
            f0: Tensor whose last axis has size ``f0_dim``.
            energy: Tensor whose last axis has size ``energy_dim``.
                NOTE(review): leading axes presumably match those of ``f0``;
                confirm against callers.

        Returns:
            A dict with a single key ``"output"`` holding a tensor whose
            trailing two axes are ``(latent_dim, 2)``: index 0 of the last
            axis is the f0 embedding, index 1 is the energy embedding.
        """
        normalized_f0 = self.f0_norm(f0)
        f0_latent = self.f0_proj(normalized_f0)
        energy_latent = self.energy_proj(energy)
        # Give each embedding a trailing singleton axis, then join on it.
        pieces = (f0_latent.unsqueeze(-1), energy_latent.unsqueeze(-1))
        return {"output": torch.cat(pieces, dim=-1)}
if __name__ == "__main__":
    # Manual smoke test: build the base text encoder and push two example
    # captions through it. The result is kept only for interactive inspection.
    sample_captions = [
        "a man is speaking",
        "a woman is singing while a dog is barking",
    ]
    encoder = T5TextEncoder(embed_dim=512)
    encoded = encoder(sample_captions)