ACE-Plus / modules /embedder.py
chaojiemao's picture
modify ace plus
d1a539d
raw
history blame
8.23 kB
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
# This file contains code that is adapted from
# https://github.com/black-forest-labs/flux.git
import warnings
import torch
import torch.utils.dlpack
import transformers
from scepter.modules.model.embedder.base_embedder import BaseEmbedder
from scepter.modules.model.registry import EMBEDDERS
from scepter.modules.model.tokenizer.tokenizer_component import (
basic_clean, canonicalize, whitespace_clean)
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.file_system import FS
try:
from transformers import AutoTokenizer, T5EncoderModel
except Exception as e:
warnings.warn(
f'Import transformers error, please deal with this problem: {e}')
@EMBEDDERS.register_class()
class ACEHFEmbedder(BaseEmbedder):
    """Text embedder backed by a HuggingFace ``transformers`` model/tokenizer.

    The model and tokenizer classes are looked up by name on the
    ``transformers`` module and loaded with ``from_pretrained`` from
    directories resolved through ``FS.get_dir_to_local_dir``. The loaded
    model is switched to eval mode with gradients disabled.
    """

    # Parameter schema handed to ``dict_to_yaml`` by ``get_config_template``.
    # NOTE: ``__init__`` also reads ``CLEAN`` and ``ADDED_IDENTIFIER``, which
    # are not listed here.
    para_dict = {
        'HF_MODEL_CLS': {
            'value': None,
            'description': 'huggingface cls in transfomer',
        },
        'MODEL_PATH': {
            'value': None,
            'description': 'model folder path',
        },
        'HF_TOKENIZER_CLS': {
            'value': None,
            'description': 'huggingface cls in transfomer',
        },
        'TOKENIZER_PATH': {
            'value': None,
            'description': 'tokenizer folder path',
        },
        'MAX_LENGTH': {
            'value': 77,
            'description': 'max length of input',
        },
        'OUTPUT_KEY': {
            'value': 'last_hidden_state',
            'description': 'output key',
        },
        'D_TYPE': {
            'value': 'float',
            'description': 'dtype',
        },
        'BATCH_INFER': {
            'value': False,
            'description': 'batch infer',
        },
    }
    para_dict.update(BaseEmbedder.para_dict)

    def __init__(self, cfg, logger=None):
        """Read the configuration and load the tokenizer and the model.

        Args:
            cfg: Config object queried through ``cfg.get``. ``HF_MODEL_CLS``,
                ``MODEL_PATH``, ``HF_TOKENIZER_CLS`` and ``TOKENIZER_PATH``
                are required; ``MAX_LENGTH``, ``OUTPUT_KEY``, ``D_TYPE``,
                ``CLEAN``, ``BATCH_INFER`` and ``ADDED_IDENTIFIER`` are
                optional.
            logger: Forwarded to the base class constructor.

        Raises:
            AssertionError: If one of the required entries is ``None``.
        """
        super().__init__(cfg, logger=logger)

        model_cls_name = cfg.get('HF_MODEL_CLS', None)
        model_dir = cfg.get('MODEL_PATH', None)
        tokenizer_cls_name = cfg.get('HF_TOKENIZER_CLS', None)
        tokenizer_dir = cfg.get('TOKENIZER_PATH', None)

        self.max_length = cfg.get('MAX_LENGTH', 77)
        self.output_key = cfg.get('OUTPUT_KEY', 'last_hidden_state')
        self.d_type = cfg.get('D_TYPE', 'float')
        self.clean = cfg.get('CLEAN', 'whitespace')
        self.batch_infer = cfg.get('BATCH_INFER', False)
        self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)

        # ``D_TYPE`` names an attribute of the ``torch`` module.
        dtype = getattr(torch, self.d_type)

        assert not (model_cls_name is None or tokenizer_cls_name is None)
        assert not (model_dir is None or tokenizer_dir is None)

        with FS.get_dir_to_local_dir(tokenizer_dir,
                                     wait_finish=True) as local_dir:
            tokenizer_cls = getattr(transformers, tokenizer_cls_name)
            self.tokenizer = tokenizer_cls.from_pretrained(
                local_dir,
                max_length=self.max_length,
                torch_dtype=dtype,
                additional_special_tokens=self.added_identifier)

        with FS.get_dir_to_local_dir(model_dir,
                                     wait_finish=True) as local_dir:
            model_cls = getattr(transformers, model_cls_name)
            self.hf_module = model_cls.from_pretrained(local_dir,
                                                       torch_dtype=dtype)

        # Inference only: eval mode, no gradients.
        frozen = self.hf_module.eval().requires_grad_(False)
        self.hf_module = frozen

    def forward(self, text: list[str], return_mask=False):
        """Tokenize ``text`` and run it through the wrapped model.

        Inputs are truncated and padded to ``self.max_length``. The model is
        called with ``attention_mask=None``, so the tokenizer's mask is not
        passed to it.

        Args:
            text: Strings to embed.
            return_mask: When true, also return the tokenizer's attention
                mask moved to the model's device.

        Returns:
            A 2-tuple ``(outputs[self.output_key], mask)`` where ``mask`` is
            ``None`` unless ``return_mask`` is true.
        """
        tokens = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding='max_length',
            return_tensors='pt',
        )
        input_ids = tokens['input_ids'].to(self.hf_module.device)
        outputs = self.hf_module(
            input_ids=input_ids,
            attention_mask=None,
            output_hidden_states=False,
        )
        embedding = outputs[self.output_key]
        if not return_mask:
            return embedding, None
        return embedding, tokens['attention_mask'].to(self.hf_module.device)

    def encode(self, text, return_mask=False):
        """Embed a string or a sequence of strings.

        A single string is wrapped in a list. If ``self.clean`` is truthy,
        every string goes through ``_clean`` first. With ``self.batch_infer``
        all strings are embedded in one call; otherwise they are embedded one
        at a time and concatenated along dimension 0.

        Args:
            text: A string or a sequence of strings.
            return_mask: When true, return ``(embedding, mask)`` instead of
                the embedding alone.

        Returns:
            The embedding, or ``(embedding, mask)`` when ``return_mask`` is
            true.
        """
        texts = [text] if isinstance(text, str) else text
        if self.clean:
            texts = [self._clean(item) for item in texts]

        if self.batch_infer:
            result = self(texts, return_mask=return_mask)
            return result if return_mask else result[0]

        # One model call per string, then stack the results.
        embeddings, masks = [], []
        for item in texts:
            item_embedding, item_mask = self([item], return_mask=return_mask)
            embeddings.append(item_embedding)
            masks.append(item_mask)

        stacked = torch.cat(embeddings, dim=0)
        if not return_mask:
            return stacked
        return stacked, torch.cat(masks, dim=0)

    def encode_list(self, text_list, return_mask=True):
        """Call ``encode`` on every element of ``text_list``.

        Args:
            text_list: Iterable whose elements are valid ``encode`` inputs.
            return_mask: Forwarded to ``encode``.

        Returns:
            A list of embeddings, or ``(embeddings, masks)`` as two parallel
            lists when ``return_mask`` is true.
        """
        encoded = [
            self.encode(item, return_mask=return_mask) for item in text_list
        ]
        if not return_mask:
            return encoded
        return [pair[0] for pair in encoded], [pair[1] for pair in encoded]

    def encode_list_of_list(self, text_list, return_mask=True):
        """Call ``encode_list`` on every element of ``text_list``.

        Args:
            text_list: Iterable whose elements are valid ``encode_list``
                inputs.
            return_mask: Forwarded to ``encode_list``.

        Returns:
            A list of ``encode_list`` embedding lists, or
            ``(embeddings, masks)`` as two parallel nested lists when
            ``return_mask`` is true.
        """
        encoded = [
            self.encode_list(group, return_mask=return_mask)
            for group in text_list
        ]
        if not return_mask:
            return encoded
        return [pair[0] for pair in encoded], [pair[1] for pair in encoded]

    def _clean(self, text):
        """Normalize ``text`` according to the ``self.clean`` mode.

        Recognized modes are ``'whitespace'``, ``'lower'`` and
        ``'canonicalize'``; any other value returns ``text`` unchanged.
        """
        mode = self.clean
        if mode == 'whitespace':
            return whitespace_clean(basic_clean(text))
        if mode == 'lower':
            return whitespace_clean(basic_clean(text)).lower()
        if mode == 'canonicalize':
            return canonicalize(basic_clean(text))
        return text

    @staticmethod
    def get_config_template():
        """Return the result of ``dict_to_yaml`` for this class's parameters."""
        class_name = __class__.__name__
        return dict_to_yaml('EMBEDDER',
                            class_name,
                            ACEHFEmbedder.para_dict,
                            set_name=True)
@EMBEDDERS.register_class()
class T5ACEPlusClipFluxEmbedder(BaseEmbedder):
    """Pair of text embedders built from ``cfg.T5_MODEL`` and ``cfg.CLIP_MODEL``.

    Every ``encode*`` method calls the method of the same name on both
    sub-embedders and returns a dict holding the T5 embedder's result under
    ``'context'`` and the CLIP embedder's result under ``'y'``.
    """

    para_dict = {'T5_MODEL': {}, 'CLIP_MODEL': {}}

    def __init__(self, cfg, logger=None):
        """Build both sub-embedders through the ``EMBEDDERS`` registry.

        Args:
            cfg: Config object exposing ``T5_MODEL`` and ``CLIP_MODEL``.
            logger: Forwarded to the base class and to both builds.
        """
        super().__init__(cfg, logger=logger)
        self.t5_model = EMBEDDERS.build(cfg.T5_MODEL, logger=logger)
        self.clip_model = EMBEDDERS.build(cfg.CLIP_MODEL, logger=logger)

    def _encode_with(self, method_name, text, return_mask):
        """Call ``method_name`` on both sub-embedders and pack the results."""
        t5_result = getattr(self.t5_model, method_name)(
            text, return_mask=return_mask)
        clip_result = getattr(self.clip_model, method_name)(
            text, return_mask=return_mask)
        # change embedding strategy here
        return {
            'context': t5_result,
            'y': clip_result,
        }

    def encode(self, text, return_mask=False):
        """Run ``encode`` on both sub-embedders.

        Returns:
            ``{'context': <t5 result>, 'y': <clip result>}``.
        """
        return self._encode_with('encode', text, return_mask)

    def encode_list(self, text, return_mask=False):
        """Run ``encode_list`` on both sub-embedders.

        Returns:
            ``{'context': <t5 result>, 'y': <clip result>}``.
        """
        return self._encode_with('encode_list', text, return_mask)

    def encode_list_of_list(self, text, return_mask=False):
        """Run ``encode_list_of_list`` on both sub-embedders.

        Returns:
            ``{'context': <t5 result>, 'y': <clip result>}``.
        """
        return self._encode_with('encode_list_of_list', text, return_mask)

    @staticmethod
    def get_config_template():
        """Return the result of ``dict_to_yaml`` for this class's parameters."""
        class_name = __class__.__name__
        return dict_to_yaml('EMBEDDER',
                            class_name,
                            T5ACEPlusClipFluxEmbedder.para_dict,
                            set_name=True)