File size: 3,392 Bytes
1a5a63d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# HunyuanDiT 蒸馏加速

语言: [**English**](https://huggingface.co/Tencent-Hunyuan/Distillation/blob/main/README.md) | **中文**

我们提供了蒸馏模型。 通过蒸馏可以降低扩散模型迭代的步数实现加速。

混元文生图大模型通过渐进式方法进行步数蒸馏,将原始DIT模型采样步数压缩50%,同时生成效果几乎无损。并且蒸馏模型可以在任何推理模式下使用。

下面的表格展示了使用蒸馏模型进行推理需要的设备,以及设备上的加速效果。这里我们在包括H800,A100, 3090, 4090等各种型号的gpu上, 以及torch 推理模式以及TensorRT模式上均进行了测速,供开发者参考。(batch size = 1)。

| GPU| CUDA version | model | inference mode | inference steps | GPU Peak Memory | inference time |
| --- | --- | --- | --- | --- | --- | --- |
| H800 | 12.1 | HunyuanDiT | PyTorch | 100 | 13G | 28s |
| H800 | 12.1 | HunyuanDiT | TensorRT | 100 | 12G | 10s |
| H800 | 12.1 | HunyuanDiT | Distill+PyTorch | 50 | 13G | 14s |
| H800 | 12.1 | HunyuanDiT | Distill+TensorRT | 50 | 12G | 5s |
| A100 | 11.7 | HunyuanDiT | PyTorch | 100 | 13GB |  54s |
| A100 | 11.7 | HunyuanDiT | TensorRT | 100 | 11GB | 20s |
| A100 | 11.7 | HunyuanDiT | Distill+PyTorch | 50 |  13GB | 25s |
| A100 | 11.7 | HunyuanDiT | Distill+TensorRT | 50 | 11GB | 10s |
| 3090 | 11.8 | HunyuanDiT | PyTorch | 100 | 14G | 98s |
| 3090 | 11.8 | HunyuanDiT | TensorRT | 100 | 14G | 40s |
| 3090 | 11.8 | HunyuanDiT | Distill+PyTorch | 50 | 14G | 49s |
| 3090 | 11.8 | HunyuanDiT | Distill+TensorRT | 50 | 14G | 20s |
| 4090 | 11.8 | HunyuanDiT | PyTorch | 100 | 14G | 54s |
| 4090 | 11.8 | HunyuanDiT | TensorRT | 100 | 14G | 20s |
| 4090 | 11.8 | HunyuanDiT | Distill+PyTorch | 50 | 14G | 27s |
| 4090 | 11.8 | HunyuanDiT | Distill+TensorRT | 50 | 14G | 10s |


## Instructions

 蒸馏模型推理所需要的安装包、依赖和 [**混元原始模型**](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT) 一致。

 模型下载使用下述指令:

```bash
cd HunyuanDiT
# 使用 huggingface-cli 工具下载模型.
huggingface-cli download Tencent-Hunyuan/Distillation-v1.2 ./pytorch_model_distill.pt --local-dir ./ckpts/t2i/model
```


## Inference

### Using Gradio

在执行如下命令前,请确保已经进入conda 环境。

```shell
# 默认指令,启动中文Gradio界面,界面中可以将采样步数降低至50步,出图效果基本无损耗
python app/hydit_app.py  --load-key distill

# 使用flash attention加速
python app/hydit_app.py --infer-mode fa --load-key distill

# 设置`--no-enhance`,prompt增强功能不可用,以降低显存占用. 
python app/hydit_app.py --no-enhance ---load-key distill

# 启动英文Gradio界面
python app/hydit_app.py --lang en --load-key distill

```


### Using Command Line

一些使用蒸馏模型的demo: 

```shell
# Prompt增强 + 文生图. Torch 模式
python sample_t2i.py --prompt "渔舟唱晚" --load-key distill  --infer-steps 50

# 文生图. Torch 模式
python sample_t2i.py --prompt "渔舟唱晚" --no-enhance --load-key distill  --infer-steps 50

# 文生图. Flash Attention 模式
python sample_t2i.py --infer-mode fa --prompt "渔舟唱晚" --load-key distill --infer-steps 50

# 切换生图分辨率.
python sample_t2i.py --prompt "渔舟唱晚" --image-size 1280 768 --load-key distill  --infer-steps 50
```