import random
import torch
from torch import nn
import numpy as np
import re
import urllib.parse as ul
from bs4 import BeautifulSoup
from einops import rearrange
from dataclasses import dataclass
from torchvision import transforms
from diffusers.models.modeling_utils import ModelMixin
from transformers import AutoImageProcessor, AutoModel
from transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer
from transformers.utils import ModelOutput
from typing import Iterable, Optional, Union, List

import step1x3d_geometry
from step1x3d_geometry.utils.typing import *

from .base import BaseCaptionEncoder
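
# Matches runs of unwanted punctuation/symbols (brackets, slashes, ®, ™, etc.);
# clean_caption() below replaces such runs with a single space.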
bad_punct_regex = re.compile(
    r"[#®•©™&@·º½¾¿¡§~()\[\]{}|\\/*]+"
)  # noqa


class T5Encoder(BaseCaptionEncoder, ModelMixin):
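    """T5-based caption encoder.

    Wraps a T5 text encoder (``T5EncoderModel``) behind the ``BaseCaptionEncoder``
    interface: captions are optionally cleaned, tokenized to ``text_max_length``
    tokens, and encoded under ``torch.no_grad()`` into per-token hidden states
    that serve as the text condition.
    """
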
    class Config(BaseCaptionEncoder.Config):
        pretrained_model_name_or_path: Optional[str] = (
            None  # pretrained name or path of the condition-model checkpoint
        )
        pretrained_t5_name_or_path: Optional[str] = (
            None  # pretrained name or path of the T5 encoder
        )
        preprocessing_text: bool = False
        text_max_length: int = 77
        t5_type: Optional[str] = None

    cfg: Config

    def configure(self) -> None:
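        """Build the T5 tokenizer/encoder and the unconditional text embeddings.

        The encoder is resolved in this order: an explicit T5 checkpoint
        (``pretrained_t5_name_or_path``); otherwise, when no condition checkpoint
        is given, an encoder built from the ``t5_type`` config; otherwise a
        t5-small/t5-base variant inferred from ``pretrained_model_name_or_path``.
        If a condition checkpoint is given, its ``caption_condition.*`` weights
        are loaded into this module at the end.
        """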
        super().configure()

        # Load the T5 model and tokenizer
        if self.cfg.pretrained_t5_name_or_path is not None:
            self.cfg.t5_type = f"google-t5/{self.cfg.pretrained_t5_name_or_path.split('google-t5--')[-1].split('/')[0]}"
            self.tokenizer = T5Tokenizer.from_pretrained(
                self.cfg.pretrained_t5_name_or_path
            )
            self.text_model = T5EncoderModel.from_pretrained(
                self.cfg.pretrained_t5_name_or_path, torch_dtype=torch.bfloat16
            )
        else:
            if (
                self.cfg.pretrained_model_name_or_path is None
            ):  # no checkpoint given: build the encoder from the t5_type config
                assert self.cfg.t5_type is not None, "The t5_type should be provided"
                print(f"Loading T5 model from {self.cfg.t5_type}")
                self.text_model = T5EncoderModel(
                    config=T5EncoderModel.config_class.from_pretrained(
                        self.cfg.t5_type,
                    )
                ).to(torch.bfloat16)
            elif "t5small" in self.cfg.pretrained_model_name_or_path:
                print("Loading T5 model from google-t5/t5-small")
                self.cfg.t5_type = "google-t5/t5-small"
                self.text_model = T5EncoderModel.from_pretrained(
                    self.cfg.t5_type, torch_dtype=torch.bfloat16
                )
            elif "t5base" in self.cfg.pretrained_model_name_or_path:
                print("Loading T5 model from google-t5/t5-base")
                self.cfg.t5_type = "google-t5/t5-base"
                self.text_model = T5EncoderModel.from_pretrained(
                    self.cfg.t5_type, torch_dtype=torch.bfloat16
                )
            else:
                raise ValueError(
                    f"Unknown T5 model: {self.cfg.pretrained_model_name_or_path}"
                )
            self.tokenizer = T5Tokenizer.from_pretrained(self.cfg.t5_type)
        # Set the empty (unconditional) text embeds
        if self.cfg.zero_uncond_embeds:
            self.empty_text_embeds = torch.zeros(
                (1, self.cfg.text_max_length, self.text_model.config.hidden_size)
            ).detach()
        else:
            self.empty_text_embeds = self.encode_text([""]).detach()

        # Load pretrained_model_name_or_path, keeping only the caption-condition
        # submodule of the checkpoint.
        if self.cfg.pretrained_model_name_or_path is not None:
            print(f"Loading ckpt from {self.cfg.pretrained_model_name_or_path}")
            ckpt = torch.load(
                self.cfg.pretrained_model_name_or_path, map_location="cpu"
            )["state_dict"]
            pretrained_model_ckpt = {}
            for k, v in ckpt.items():
                if k.startswith("caption_condition."):
                    pretrained_model_ckpt[k.replace("caption_condition.", "")] = v
            self.load_state_dict(pretrained_model_ckpt, strict=True)
    def clean_caption(self, caption):
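        """Normalize a raw caption string.

        Strips URLs, HTML, @mentions, CJK character ranges, IDs, filenames, and
        stray punctuation, and normalizes dashes, quotes, and whitespace. This
        closely follows the caption cleaning used in DeepFloyd-IF/PixArt-style
        text-conditioned pipelines.
        """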
        caption = str(caption)
        caption = ul.unquote_plus(caption)
        caption = caption.strip().lower()
        caption = re.sub("<person>", "person", caption)
        # urls:
        caption = re.sub(
            r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        caption = re.sub(
            r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        # html:
        caption = BeautifulSoup(caption, features="html.parser").text
        # @<nickname>
        caption = re.sub(r"@[\w\d]+\b", "", caption)
        # 31C0—31EF CJK Strokes
        # 31F0—31FF Katakana Phonetic Extensions
        # 3200—32FF Enclosed CJK Letters and Months
        # 3300—33FF CJK Compatibility
        # 3400—4DBF CJK Unified Ideographs Extension A
        # 4DC0—4DFF Yijing Hexagram Symbols
        # 4E00—9FFF CJK Unified Ideographs
        caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
        caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
        caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
        caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
        caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
        caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
        caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
        #######################################################
        # all types of dash --> "-"
        caption = re.sub(
            r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
            "-",
            caption,
        )
        # normalize quotation marks to a single standard
        caption = re.sub(r"[`´«»“”¨]", '"', caption)
        caption = re.sub(r"[‘’]", "'", caption)
        # &quot;
        caption = re.sub(r"&quot;?", "", caption)
        # &amp
        caption = re.sub(r"&amp", "", caption)
        # ip addresses:
        caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
        # article ids:
        caption = re.sub(r"\d:\d\d\s+$", "", caption)
        # \n
        caption = re.sub(r"\\n", " ", caption)
# "#123" | |
caption = re.sub(r"#\d{1,3}\b", "", caption) | |
# "#12345.." | |
caption = re.sub(r"#\d{5,}\b", "", caption) | |
# "123456.." | |
caption = re.sub(r"\b\d{6,}\b", "", caption) | |
# filenames: | |
caption = re.sub( | |
r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption | |
) | |
# | |
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" | |
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" | |
caption = re.sub( | |
bad_punct_regex, r" ", caption | |
) # ***AUSVERKAUFT***, #AUSVERKAUFT | |
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " | |
# this-is-my-cute-cat / this_is_my_cute_cat | |
regex2 = re.compile(r"(?:\-|\_)") | |
if len(re.findall(regex2, caption)) > 3: | |
caption = re.sub(regex2, " ", caption) | |
caption = self.basic_clean(caption) | |
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 | |
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc | |
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 | |
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) | |
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) | |
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) | |
caption = re.sub( | |
r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption | |
) | |
caption = re.sub(r"\bpage\s+\d+\b", "", caption) | |
caption = re.sub( | |
r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption | |
) # j2d1a2a... | |
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) | |
caption = re.sub(r"\b\s+\:\s+", r": ", caption) | |
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) | |
caption = re.sub(r"\s+", " ", caption) | |
caption.strip() | |
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) | |
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) | |
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) | |
caption = re.sub(r"^\.\S+$", "", caption) | |
return caption.strip() | |
    def text_preprocessing(self, text):
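        """Lightweight wrapper over :meth:`clean_caption`.

        Runs the full cleaning pipeline when ``cfg.preprocessing_text`` is set;
        otherwise just lowercases and strips the text.
        """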
        if self.cfg.preprocessing_text:
            # The exact text cleaning used during the training stage:
            text = self.clean_caption(text)
            return text
        else:
            return text.lower().strip()
    def encode_text(self, texts: List[str]) -> torch.FloatTensor:
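        """Encode a batch of captions into T5 token features.

        Returns the T5 encoder's last hidden states with shape
        ``(len(texts), cfg.text_max_length, hidden_size)``; inputs are always
        padded/truncated to ``text_max_length`` and gradients are not tracked.
        """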
        texts = [self.text_preprocessing(text) for text in texts]
        text_tokens_and_mask = self.tokenizer(
            texts,
            max_length=self.cfg.text_max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        # input_ids / attention_mask: (N, text_max_length)
        with torch.no_grad():
            label_embeds = self.text_model(
                input_ids=text_tokens_and_mask["input_ids"].to(self.text_model.device),
                attention_mask=text_tokens_and_mask["attention_mask"].to(
                    self.text_model.device
                ),
            )["last_hidden_state"].detach()
        return label_embeds
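

# Example usage (a minimal sketch, not part of the original module). It assumes the
# encoder is instantiated the way step1x3d_geometry modules usually are, i.e. by
# passing a config dict/DictConfig that the base module parses into ``Config``; the
# exact constructor signature comes from ``BaseCaptionEncoder`` and may differ.
#
#   encoder = T5Encoder(
#       {"pretrained_t5_name_or_path": "google-t5/t5-base", "zero_uncond_embeds": True}
#   )
#   embeds = encoder.encode_text(["a wooden chair with four curved legs"])
#   # embeds: (1, 77, hidden_size) bfloat16 features from the T5 encoder
#   # (computed under torch.no_grad())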