import torch
from torch import Tensor, nn


class NewTokenEmb(nn.Module):
    """For adding new tokens to a pretrained model's embedding layer."""

    def __init__(self, old_embeddings: nn.Embedding,
                 new_num_tokens: int) -> None:
        super().__init__()

        self.old_num_tokens = old_embeddings.num_embeddings
        self.new_num_tokens = new_num_tokens
        self.num_tokens = self.old_num_tokens + new_num_tokens
        self.embedding_dim = old_embeddings.embedding_dim

        # Text embeddings: copy the pretrained weights into the first
        # `old_num_tokens` rows, zero-initialize the new rows, and freeze
        # the whole table so the pretrained text vocabulary is preserved.
        self.text_embeddings = nn.Embedding(
            self.num_tokens,
            self.embedding_dim,
            device=old_embeddings.weight.device,
            dtype=old_embeddings.weight.dtype)
        with torch.no_grad():
            self.text_embeddings.weight.data[:self.old_num_tokens] = (
                old_embeddings.weight.data)
            self.text_embeddings.weight.data[self.old_num_tokens:] = torch.zeros(
                self.new_num_tokens,
                self.embedding_dim,
                dtype=old_embeddings.weight.dtype,
                device=old_embeddings.weight.device)
        self.text_embeddings.weight.requires_grad_(False)

        # Motion embeddings: sized to the full extended vocabulary so the
        # two lookups can be summed in forward(); the rows belonging to the
        # original text tokens are kept at zero.
        self.motion_embeddings = nn.Embedding(
            self.num_tokens,
            self.embedding_dim,
            device=old_embeddings.weight.device,
            dtype=old_embeddings.weight.dtype)
        with torch.no_grad():
            self.motion_embeddings.weight.data[:self.old_num_tokens] = torch.zeros(
                self.old_num_tokens,
                self.embedding_dim,
                dtype=old_embeddings.weight.dtype,
                device=old_embeddings.weight.device)

        # Projects the frozen text table along the token axis to derive
        # initial rows for the new motion tokens.
        self.word2motionProj = nn.Linear(self.old_num_tokens, new_num_tokens)

    def forward(self, input: Tensor) -> Tensor:
        # Re-zero the text-token rows of the motion table so that, for text
        # ids, the sum below reduces to the frozen text embedding alone.
        with torch.no_grad():
            self.motion_embeddings.weight.data[:self.old_num_tokens] = torch.zeros(
                self.old_num_tokens,
                self.embedding_dim,
                dtype=self.motion_embeddings.weight.dtype,
                device=self.motion_embeddings.weight.device)

        # Fill the motion-token rows with a projection of the text table:
        # (old_num_tokens, dim) -> (dim, old_num_tokens) -> Linear ->
        # (dim, new_num_tokens) -> (new_num_tokens, dim). The write goes
        # through .data, i.e. outside autograd.
        self.motion_embeddings.weight.data[self.old_num_tokens:] = (
            self.word2motionProj(
                self.text_embeddings.weight.data[:self.old_num_tokens].permute(
                    1, 0)).permute(1, 0))

        return self.text_embeddings(input) + self.motion_embeddings(input)
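

# Minimal usage sketch (illustrative only). The sizes below are hypothetical;
# in practice `old` would be the input-embedding table of a pretrained
# language model and `new_num_tokens` the size of the token set being
# grafted onto its vocabulary.
if __name__ == "__main__":
    old = nn.Embedding(100, 32)  # stand-in for a pretrained embedding table
    emb = NewTokenEmb(old, new_num_tokens=16)

    # Token ids may now range over the extended vocabulary [0, 100 + 16).
    ids = torch.tensor([[0, 5, 99, 100, 115]])
    out = emb(ids)
    print(out.shape)  # torch.Size([1, 5, 32])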