# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
# This file contains code that is adapted from
# https://github.com/black-forest-labs/flux.git
import warnings

import torch
import torch.utils.dlpack
import transformers

from scepter.modules.model.embedder.base_embedder import BaseEmbedder
from scepter.modules.model.registry import EMBEDDERS
from scepter.modules.model.tokenizer.tokenizer_component import (
    basic_clean, canonicalize, whitespace_clean)
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.file_system import FS

try:
    from transformers import AutoTokenizer, T5EncoderModel
except Exception as e:
    warnings.warn(
        f'Import transformers error, please deal with this problem: {e}')


@EMBEDDERS.register_class()
class ACEHFEmbedder(BaseEmbedder):
    """Generic Hugging Face text embedder: loads a tokenizer/encoder pair by
    class name and returns the tensor stored under ``OUTPUT_KEY``."""
    para_dict = {
        'HF_MODEL_CLS': {
            'value': None,
            'description': 'huggingface model class name in transformers'
        },
        'MODEL_PATH': {
            'value': None,
            'description': 'model folder path'
        },
        'HF_TOKENIZER_CLS': {
            'value': None,
            'description': 'huggingface tokenizer class name in transformers'
        },
        'TOKENIZER_PATH': {
            'value': None,
            'description': 'tokenizer folder path'
        },
        'MAX_LENGTH': {
            'value': 77,
            'description': 'max length of input'
        },
        'OUTPUT_KEY': {
            'value': 'last_hidden_state',
            'description': 'output key'
        },
        'D_TYPE': {
            'value': 'float',
            'description': 'dtype'
        },
        'BATCH_INFER': {
            'value': False,
            'description': 'batch infer'
        }
    }
    para_dict.update(BaseEmbedder.para_dict)
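
    # Illustrative config sketch (not shipped with this file): the keys below
    # mirror ``para_dict``; the class names and paths are placeholders that
    # must point to real checkpoints, e.g. when this embedder serves as the
    # T5 branch of an ACE/FLUX pipeline.
    #
    #   NAME: ACEHFEmbedder
    #   HF_MODEL_CLS: T5EncoderModel
    #   MODEL_PATH: <folder with the T5 encoder weights>
    #   HF_TOKENIZER_CLS: AutoTokenizer
    #   TOKENIZER_PATH: <folder with the matching tokenizer>
    #   MAX_LENGTH: 256
    #   OUTPUT_KEY: last_hidden_state
    #   D_TYPE: bfloat16
    #   BATCH_INFER: false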

    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger=logger)
        hf_model_cls = cfg.get('HF_MODEL_CLS', None)
        model_path = cfg.get('MODEL_PATH', None)
        hf_tokenizer_cls = cfg.get('HF_TOKENIZER_CLS', None)
        tokenizer_path = cfg.get('TOKENIZER_PATH', None)
        self.max_length = cfg.get('MAX_LENGTH', 77)
        self.output_key = cfg.get('OUTPUT_KEY', 'last_hidden_state')
        self.d_type = cfg.get('D_TYPE', 'float')
        self.clean = cfg.get('CLEAN', 'whitespace')
        self.batch_infer = cfg.get('BATCH_INFER', False)
        self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
        torch_dtype = getattr(torch, self.d_type)

        assert hf_model_cls is not None and hf_tokenizer_cls is not None
        assert model_path is not None and tokenizer_path is not None

        # Fetch the tokenizer/model folders to local storage (if remote) and
        # load them via the configured transformers classes.
        with FS.get_dir_to_local_dir(tokenizer_path,
                                     wait_finish=True) as local_path:
            self.tokenizer = getattr(
                transformers, hf_tokenizer_cls).from_pretrained(
                    local_path,
                    max_length=self.max_length,
                    torch_dtype=torch_dtype,
                    additional_special_tokens=self.added_identifier)
        with FS.get_dir_to_local_dir(model_path,
                                     wait_finish=True) as local_path:
            self.hf_module = getattr(
                transformers, hf_model_cls).from_pretrained(
                    local_path, torch_dtype=torch_dtype)
        # The text encoder is frozen: eval mode, no gradients.
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str], return_mask=False):
        """Tokenize a batch of prompts and run the frozen text encoder."""
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding='max_length',
            return_tensors='pt',
        )
        outputs = self.hf_module(
            input_ids=batch_encoding['input_ids'].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        if return_mask:
            return outputs[self.output_key], batch_encoding[
                'attention_mask'].to(self.hf_module.device)
        else:
            return outputs[self.output_key], None

    def encode(self, text, return_mask=False):
        if isinstance(text, str):
            text = [text]
        if self.clean:
            text = [self._clean(u) for u in text]
        if not self.batch_infer:
            # Encode prompts one by one and concatenate the results.
            cont, mask = [], []
            for tt in text:
                one_cont, one_mask = self([tt], return_mask=return_mask)
                cont.append(one_cont)
                mask.append(one_mask)
            if return_mask:
                return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
            else:
                return torch.cat(cont, dim=0)
        else:
            ret_data = self(text, return_mask=return_mask)
            if return_mask:
                return ret_data
            else:
                return ret_data[0]

    def encode_list(self, text_list, return_mask=True):
        cont_list = []
        mask_list = []
        for pp in text_list:
            cont = self.encode(pp, return_mask=return_mask)
            if return_mask:
                cont_list.append(cont[0])
                mask_list.append(cont[1])
            else:
                cont_list.append(cont)
                mask_list.append(None)
        if return_mask:
            return cont_list, mask_list
        else:
            return cont_list

    def encode_list_of_list(self, text_list, return_mask=True):
        cont_list = []
        mask_list = []
        for pp in text_list:
            cont = self.encode_list(pp, return_mask=return_mask)
            if return_mask:
                cont_list.append(cont[0])
                mask_list.append(cont[1])
            else:
                cont_list.append(cont)
                mask_list.append(None)
        if return_mask:
            return cont_list, mask_list
        else:
            return cont_list

    def _clean(self, text):
        if self.clean == 'whitespace':
            text = whitespace_clean(basic_clean(text))
        elif self.clean == 'lower':
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == 'canonicalize':
            text = canonicalize(basic_clean(text))
        return text

    @staticmethod
    def get_config_template():
        return dict_to_yaml('EMBEDDER',
                            __class__.__name__,
                            ACEHFEmbedder.para_dict,
                            set_name=True)
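
# Minimal usage sketch (assumption, not part of the original file): the Config
# construction below follows the common scepter pattern
# ``Config(cfg_dict=..., load=False)``; paths are placeholders.
#
#   from scepter.modules.utils.config import Config
#   cfg = Config(cfg_dict={
#       'NAME': 'ACEHFEmbedder',
#       'HF_MODEL_CLS': 'T5EncoderModel',
#       'MODEL_PATH': '<t5-encoder-folder>',
#       'HF_TOKENIZER_CLS': 'AutoTokenizer',
#       'TOKENIZER_PATH': '<t5-tokenizer-folder>',
#       'MAX_LENGTH': 256,
#       'OUTPUT_KEY': 'last_hidden_state',
#       'D_TYPE': 'bfloat16',
#   }, load=False)
#   embedder = ACEHFEmbedder(cfg)
#   emb, mask = embedder.encode('a photo of a cat', return_mask=True)
#   # emb: [1, MAX_LENGTH, hidden_dim], mask: [1, MAX_LENGTH]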


@EMBEDDERS.register_class()
class T5ACEPlusClipFluxEmbedder(BaseEmbedder):
    """
    Combines a T5 text encoder and a CLIP text encoder and returns their
    outputs as the 'context' and 'y' conditioning inputs, respectively.
    """
    para_dict = {
        'T5_MODEL': {},
        'CLIP_MODEL': {}
    }
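
    # Illustrative config sketch (placeholders, not a shipped config): each
    # sub-key holds a full embedder config that EMBEDDERS.build() resolves,
    # typically two ACEHFEmbedder entries.
    #
    #   NAME: T5ACEPlusClipFluxEmbedder
    #   T5_MODEL:
    #     NAME: ACEHFEmbedder
    #     HF_MODEL_CLS: T5EncoderModel
    #     OUTPUT_KEY: last_hidden_state
    #     ...
    #   CLIP_MODEL:
    #     NAME: ACEHFEmbedder
    #     HF_MODEL_CLS: CLIPTextModel
    #     OUTPUT_KEY: pooler_output
    #     ...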

    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger=logger)
        self.t5_model = EMBEDDERS.build(cfg.T5_MODEL, logger=logger)
        self.clip_model = EMBEDDERS.build(cfg.CLIP_MODEL, logger=logger)

    def encode(self, text, return_mask=False):
        t5_embeds = self.t5_model.encode(text, return_mask=return_mask)
        clip_embeds = self.clip_model.encode(text, return_mask=return_mask)
        # change embedding strategy here
        return {
            'context': t5_embeds,
            'y': clip_embeds,
        }

    def encode_list(self, text, return_mask=False):
        t5_embeds = self.t5_model.encode_list(text, return_mask=return_mask)
        clip_embeds = self.clip_model.encode_list(text,
                                                  return_mask=return_mask)
        # change embedding strategy here
        return {
            'context': t5_embeds,
            'y': clip_embeds,
        }

    def encode_list_of_list(self, text, return_mask=False):
        t5_embeds = self.t5_model.encode_list_of_list(text,
                                                      return_mask=return_mask)
        clip_embeds = self.clip_model.encode_list_of_list(
            text, return_mask=return_mask)
        # change embedding strategy here
        return {
            'context': t5_embeds,
            'y': clip_embeds,
        }

    @staticmethod
    def get_config_template():
        return dict_to_yaml('EMBEDDER',
                            __class__.__name__,
                            T5ACEPlusClipFluxEmbedder.para_dict,
                            set_name=True)
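
# End-to-end sketch (assumption: ``cfg`` is a scepter Config whose T5_MODEL /
# CLIP_MODEL entries name registered embedders such as ACEHFEmbedder above;
# shapes are examples, not guarantees).
#
#   embedder = EMBEDDERS.build(cfg, logger=None)
#   cond = embedder.encode(['a watercolor painting of a fox'])
#   context = cond['context']  # T5 sequence embedding, e.g. [B, MAX_LENGTH, d_t5]
#   y = cond['y']              # CLIP embedding routed to the model's 'y' input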