Distillation / README_zh.md
Zhiminli's picture
Upload folder using huggingface_hub
8ce1075 verified

HunyuanDiT 蒸馏加速

语言: English | 中文

我们提供了蒸馏模型。 通过蒸馏可以降低扩散模型迭代的步数实现加速。

混元文生图大模型通过渐进式方法进行步数蒸馏,将原始DIT模型采样步数压缩50%,同时生成效果几乎无损。并且蒸馏模型可以在任何推理模式下使用。

下面的表格展示了使用蒸馏模型进行推理需要的设备,以及设备上的加速效果。这里我们在包括H800,A100, 3090, 4090等各种型号的gpu上, 以及torch 推理模式以及TensorRT模式上均进行了测速,供开发者参考。(batch size = 1)。

GPU CUDA version model inference mode inference steps GPU Peak Memory inference time
H800 12.1 HunyuanDiT PyTorch 100 13G 28s
H800 12.1 HunyuanDiT TensorRT 100 12G 10s
H800 12.1 HunyuanDiT Distill+PyTorch 50 13G 14s
H800 12.1 HunyuanDiT Distill+TensorRT 50 12G 5s
A100 11.7 HunyuanDiT PyTorch 100 13GB 54s
A100 11.7 HunyuanDiT TensorRT 100 11GB 20s
A100 11.7 HunyuanDiT Distill+PyTorch 50 13GB 25s
A100 11.7 HunyuanDiT Distill+TensorRT 50 11GB 10s
3090 11.8 HunyuanDiT PyTorch 100 14G 98s
3090 11.8 HunyuanDiT TensorRT 100 14G 40s
3090 11.8 HunyuanDiT Distill+PyTorch 50 14G 49s
3090 11.8 HunyuanDiT Distill+TensorRT 50 14G 20s
4090 11.8 HunyuanDiT PyTorch 100 14G 54s
4090 11.8 HunyuanDiT TensorRT 100 14G 20s
4090 11.8 HunyuanDiT Distill+PyTorch 50 14G 27s
4090 11.8 HunyuanDiT Distill+TensorRT 50 14G 10s

Instructions

蒸馏模型推理所需要的安装包、依赖和 混元原始模型 一致。

模型下载使用下述指令:

cd HunyuanDiT
# 使用 huggingface-cli 工具下载模型.
huggingface-cli download Tencent-Hunyuan/Distillation ./pytorch_model_distill.pt --local-dir ./ckpts/t2i/model

Inference

Using Gradio

在执行如下命令前,请确保已经进入conda 环境。

# 默认指令,启动中文Gradio界面,界面中可以将采样步数降低至50步,出图效果基本无损耗
python app/hydit_app.py  --load-key distill

# 使用flash attention加速
python app/hydit_app.py --infer-mode fa --load-key distill

# 设置`--no-enhance`,prompt增强功能不可用,以降低显存占用. 
python app/hydit_app.py --no-enhance ---load-key distill

# 启动英文Gradio界面
python app/hydit_app.py --lang en --load-key distill

Using Command Line

一些使用蒸馏模型的demo:

# Prompt增强 + 文生图. Torch 模式
python sample_t2i.py --prompt "渔舟唱晚" --load-key distill  --infer-steps 50

# 文生图. Torch 模式
python sample_t2i.py --prompt "渔舟唱晚" --no-enhance --load-key distill  --infer-steps 50

# 文生图. Flash Attention 模式
python sample_t2i.py --infer-mode fa --prompt "渔舟唱晚" --load-key distill --infer-steps 50

# 切换生图分辨率.
python sample_t2i.py --prompt "渔舟唱晚" --image-size 1280 768 --load-key distill  --infer-steps 50