# JiT-AnimeFace-Demo/model/class_encoder.py
import warnings
from typing import NamedTuple
import torch
import torch.nn as nn
PromptType = str | list[str]
class ClassTokenizerOutput(NamedTuple):
class_ids: torch.Tensor
attention_mask: torch.Tensor
class ClassTokenizer:
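    """Map splitter-separated class-name prompts (space-separated by default)
    to padded ID sequences.

    The ID ``len(label2id)`` is reserved as the padding index; labels missing
    from ``label2id`` are skipped with a warning.
    """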
def __init__(
self,
label2id: dict[str, int],
splitter: str = " ",
) -> None:
self.label2id = label2id
self.id2label = {v: k for k, v in label2id.items()}
self.splitter = splitter
self.pad_token_id = len(label2id)
        assert all(label_id < len(label2id) for label_id in label2id.values()), (
            "All label IDs must be less than the number of classes."
        )
def normalize_prompts(
self,
class_names: PromptType,
) -> list[str]:
_class_names: list[str] = (
class_names if isinstance(class_names, list) else [class_names]
)
return _class_names
def tokenize(
self,
prompts: PromptType,
max_length: int = 32,
) -> ClassTokenizerOutput:
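        """Tokenize one prompt or a list of prompts into padded class IDs.

        Returns a ``ClassTokenizerOutput`` whose ``class_ids`` and
        ``attention_mask`` tensors both have shape ``(batch, max_length)``.
        For example (made-up labels), with ``label2id={"1girl": 0,
        "blue_eyes": 1}`` and ``max_length=4``, the prompt
        ``"1girl blue_eyes"`` becomes ``class_ids=[[0, 1, 2, 2]]`` and
        ``attention_mask=[[1, 1, 0, 0]]``, where ``2`` is the padding ID.
        """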
# 1. Normalize class names
_prompts = self.normalize_prompts(prompts)
# 2. Convert to IDs
        class_ids = []
for text in _prompts:
ids = []
            for label in text.split(self.splitter):
                label = label.strip()
                if label == "":
                    continue
                label_id = self.label2id.get(label)
                if label_id is not None:  # 0 is a valid ID
                    ids.append(label_id)
                else:
                    warnings.warn(f"Label '{label}' not found in label2id mapping.")
            class_ids.append(ids)
# 3. Pad to max_length
padded_class_ids = []
padded_masks = []
        for ids in class_ids:
if len(ids) < max_length:
mask = [1] * len(ids) + [0] * (max_length - len(ids))
ids = ids + [self.pad_token_id] * (max_length - len(ids)) # padding idx
else:
mask = [1] * max_length
ids = ids[:max_length]
padded_class_ids.append(ids)
padded_masks.append(mask)
return ClassTokenizerOutput(
class_ids=torch.tensor(padded_class_ids, dtype=torch.long),
attention_mask=torch.tensor(padded_masks, dtype=torch.long),
)
class ClassEncoderOutput(NamedTuple):
embeddings: torch.Tensor
attention_mask: torch.Tensor
class ClassEncoder(nn.Module):
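    """Embedding-table encoder that turns class-name prompts into per-token embeddings."""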
def __init__(
self,
label2id: dict[str, int],
embedding_dim: int,
):
super().__init__()
self.num_classes = len(label2id)
self.pad_token_id = self.num_classes # padding idx
self.embedding = nn.Embedding(
self.num_classes + 1, # +1 for padding idx
embedding_dim,
padding_idx=self.num_classes,
)
self.tokenizer = ClassTokenizer(label2id)
    def initialize_weights(self):
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
        with torch.no_grad():  # keep the padding row at zero after re-init
            self.embedding.weight[self.pad_token_id].zero_()
def encode_prompts(
self,
prompts: PromptType,
max_token_length: int = 32,
    ) -> ClassEncoderOutput:
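        """Tokenize ``prompts`` and look up their class embeddings.

        Returns ``embeddings`` of shape ``(batch, max_token_length,
        embedding_dim)`` together with the matching ``attention_mask``.
        """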
# 1. Tokenize prompts
class_ids, attention_mask = self.tokenizer.tokenize(
prompts,
max_length=max_token_length,
)
        # 2. Get embeddings
embeddings = self.embedding(class_ids.to(self.embedding.weight.device))
return ClassEncoderOutput(
embeddings=embeddings,
attention_mask=attention_mask,
)
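

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The label names,
    # embedding size, and prompts below are made-up placeholders.
    demo_label2id = {"1girl": 0, "blue_eyes": 1, "smile": 2}
    encoder = ClassEncoder(demo_label2id, embedding_dim=16)
    encoder.initialize_weights()

    out = encoder.encode_prompts(["1girl blue_eyes", "smile"], max_token_length=4)
    print(out.embeddings.shape)   # torch.Size([2, 4, 16])
    print(out.attention_mask)     # [[1, 1, 0, 0], [1, 0, 0, 0]]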