Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| import einops | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.init as init | |
class PoseNet(nn.Module):
    """Tiny convolutional encoder that turns pose images into a conditioning signal.

    The conv stack expects 3-channel input. It interleaves 3x3 stride-1
    convolutions with three 4x4 stride-2 convolutions, each of which halves
    height and width (rounding down). It ends with a 3x3 convolution that
    widens to 128 channels. A 1x1 projection then maps those 128 channels to
    ``noise_latent_channels``, and the result is multiplied by a learnable
    scalar ``scale`` (initialized to 2).

    Args:
        noise_latent_channels: Number of output channels. Presumably this must
            match the channel count of whatever latent the output is combined
            with. TODO: confirm against the caller.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, noise_latent_channels=320, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Feature extractor: 3x3 "refine" convs interleaved with three 4x4
        # stride-2 "downsample" convs, all followed by SiLU.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.SiLU(),
        )
        # 1x1 projection to the requested number of output channels.
        self.final_proj = nn.Conv2d(
            in_channels=128, out_channels=noise_latent_channels, kernel_size=1
        )
        self._initialize_weights()
        # Learnable global output scale. It is created after
        # _initialize_weights, which only touches the conv layers.
        self.scale = nn.Parameter(torch.ones(1) * 2)

    def _initialize_weights(self):
        """Initialize the conv stack with He (Kaiming) normal weights and zero biases.

        ``final_proj`` gets all-zero weights and bias, so the module's output
        is exactly zero at initialization. Presumably the intent is that the
        pose condition starts as a no-op for the downstream model. TODO: confirm.
        """
        for m in self.conv_layers:
            if isinstance(m, nn.Conv2d):
                # fan_in = kernel area * input channels; std = sqrt(2 / fan_in).
                n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
                init.normal_(m.weight, mean=0.0, std=np.sqrt(2. / n))
                if m.bias is not None:
                    init.zeros_(m.bias)
        init.zeros_(self.final_proj.weight)
        if self.final_proj.bias is not None:
            init.zeros_(self.final_proj.bias)

    def forward(self, x):
        """Encode pose images into conditioning features.

        Args:
            x: A 4-D tensor ``(N, 3, H, W)``, or a 5-D tensor
                ``(b, f, c, h, w)``. A 5-D input is flattened to
                ``(b*f, c, h, w)`` before encoding, and the output stays in
                that flattened layout.

        Returns:
            Tensor of shape ``(N, noise_latent_channels, H', W')``. ``H'`` and
            ``W'`` are the input sizes after three floor-halvings.
        """
        if x.ndim == 5:
            x = einops.rearrange(x, "b f c h w -> (b f) c h w")
        x = self.conv_layers(x)
        x = self.final_proj(x)
        return x * self.scale

    @classmethod
    def from_pretrained(cls, pretrained_model_path, noise_latent_channels=320):
        """Build a model of type ``cls`` and load saved weights into it.

        Args:
            pretrained_model_path: Path to a checkpoint readable by
                ``torch.load`` that contains a PoseNet state dict.
            noise_latent_channels: Output channels of the model to build. This
                must match the checkpoint. The default 320 is the value that
                used to be hard-coded.

        Returns:
            A new ``cls`` instance with the weights loaded (``strict=True``).

        Raises:
            FileNotFoundError: If ``pretrained_model_path`` does not exist.
        """
        if not Path(pretrained_model_path).exists():
            # Fail fast with a clear message instead of printing and then
            # letting torch.load fail on the missing file.
            raise FileNotFoundError(f"There is no model file in {pretrained_model_path}")
        # NOTE(security): torch.load unpickles the file. Only load
        # checkpoints from trusted sources.
        state_dict = torch.load(pretrained_model_path, map_location="cpu")
        # Use cls so that subclasses get an instance of their own type.
        model = cls(noise_latent_channels=noise_latent_channels)
        model.load_state_dict(state_dict, strict=True)
        # Report success only after the weights have actually been loaded.
        print(f"loaded PoseNet's pretrained weights from {pretrained_model_path}.")
        return model