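"""Build the Kolors-based ControlNet img2img pipeline for texture generation."""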
import os

import torch
from diffusers import AutoencoderKL, DiffusionPipeline, EulerDiscreteScheduler
from huggingface_hub import snapshot_download
from kolors.models.controlnet import ControlNetModel
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.unet_2d_condition import UNet2DConditionModel
from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import (
    StableDiffusionXLControlNetImg2ImgPipeline,
)
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from embodied_gen.models.text_model import download_kolors_weights
from embodied_gen.utils.log import logger

__all__ = [
    "build_texture_gen_pipe",
]


def build_texture_gen_pipe(
    base_ckpt_dir: str,
    controlnet_ckpt: str | None = None,
    ip_adapt_scale: float = 0,
    device: str = "cuda",
) -> DiffusionPipeline:
    """Assemble the Kolors ControlNet img2img pipeline for texture generation.

    Args:
        base_ckpt_dir: Directory containing the Kolors (and, if used,
            IP-Adapter) weights.
        controlnet_ckpt: Path to a ControlNet checkpoint. If None, the default
            multi-view texture checkpoint is downloaded from the hub.
        ip_adapt_scale: IP-Adapter conditioning scale; the adapter is loaded
            only when > 0.
        device: Device the assembled pipeline is moved to.
    """
    download_kolors_weights(f"{base_ckpt_dir}/Kolors")
    logger.info("Loading Kolors weights...")

    # Load the frozen Kolors components in fp16; the ChatGLM tokenizer ships
    # in the same folder as the text encoder weights.
    tokenizer = ChatGLMTokenizer.from_pretrained(
        f"{base_ckpt_dir}/Kolors/text_encoder"
    )
    text_encoder = ChatGLMModel.from_pretrained(
        f"{base_ckpt_dir}/Kolors/text_encoder", torch_dtype=torch.float16
    ).half()
    vae = AutoencoderKL.from_pretrained(
        f"{base_ckpt_dir}/Kolors/vae", revision=None
    ).half()
    unet = UNet2DConditionModel.from_pretrained(
        f"{base_ckpt_dir}/Kolors/unet", revision=None
    ).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(
        f"{base_ckpt_dir}/Kolors/scheduler"
    )

    if controlnet_ckpt is None:
        # Fall back to the released multi-view texture ControlNet checkpoint.
        suffix = "texture_gen_mv_v1"
        model_path = snapshot_download(
            repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
        )
        controlnet_ckpt = os.path.join(model_path, suffix)

    controlnet = ControlNetModel.from_pretrained(
        controlnet_ckpt, use_safetensors=True
    ).half()

    # The CLIP image encoder is only required when IP-Adapter image prompting
    # is enabled via ip_adapt_scale > 0.
    image_encoder = None
    clip_image_processor = None
    if ip_adapt_scale > 0:
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus/image_encoder",
            ignore_mismatched_sizes=True,
        ).to(dtype=torch.float16)
        ip_img_size = 336
        clip_image_processor = CLIPImageProcessor(
            size=ip_img_size, crop_size=ip_img_size
        )

    pipe = StableDiffusionXLControlNetImg2ImgPipeline(
        vae=vae,
        controlnet=controlnet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        image_encoder=image_encoder,
        feature_extractor=clip_image_processor,
        force_zeros_for_empty_prompt=False,
    )

    if ip_adapt_scale > 0:
        # Keep the text projection reachable before load_ip_adapter repurposes
        # encoder_hid_proj for image embeddings.
        if hasattr(pipe.unet, "encoder_hid_proj"):
            pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
        pipe.load_ip_adapter(
            f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus",
            subfolder="",
            weight_name=["ip_adapter_plus_general.bin"],
        )
        pipe.set_ip_adapter_scale([ip_adapt_scale])

    pipe = pipe.to(device)

    return pipe
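

# Minimal usage sketch. The checkpoint directory below is an assumption for
# illustration; point base_ckpt_dir at wherever the Kolors weights are stored
# (missing weights are fetched by download_kolors_weights).
if __name__ == "__main__":
    pipe = build_texture_gen_pipe(
        base_ckpt_dir="weights",  # hypothetical path
        ip_adapt_scale=0.5,  # > 0 also loads Kolors-IP-Adapter-Plus
        device="cuda",
    )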