update diffuser usage
Browse files
README.md
CHANGED
@@ -27,7 +27,48 @@ Our 4-step (much higher quality, 2X slower) Text-to-Image demo is hosted at [DMD
|
|
27 |
Our 1-step Text-to-Image demo is hosted at [DMD2-1step](https://154dfe6ee5c63946cc.gradio.live)
|
28 |
|
29 |
## Usage
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
|
33 |
## License
|
|
|
27 |
Our 1-step Text-to-Image demo is hosted at [DMD2-1step](https://154dfe6ee5c63946cc.gradio.live)
|
28 |
|
29 |
## Usage
|
30 |
+
|
31 |
+
We can use the standard diffuser pipeline:
|
32 |
+
|
33 |
+
#### 4-step generation
|
34 |
+
|
35 |
+
```.bash
|
36 |
+
import torch
|
37 |
+
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
|
38 |
+
from huggingface_hub import hf_hub_download
|
39 |
+
from safetensors.torch import load_file
|
40 |
+
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
41 |
+
repo_name = "tianweiy/DMD2"
|
42 |
+
ckpt_name = "dmd2_sdxl_4step_unet.bin"
|
43 |
+
# Load model.
|
44 |
+
unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to("cuda", torch.float16)
|
45 |
+
unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name), map_location="cuda"))
|
46 |
+
pipe = DiffusionPipeline.from_pretrained(base_model_id, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
|
47 |
+
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
48 |
+
prompt="a photo of a cat"
|
49 |
+
image=pipe(prompt=prompt, num_inference_steps=4, guidance_scale=0).images[0]
|
50 |
+
```
|
51 |
+
|
52 |
+
#### 1-step generation
|
53 |
+
|
54 |
+
```.bash
|
55 |
+
import torch
|
56 |
+
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
|
57 |
+
from huggingface_hub import hf_hub_download
|
58 |
+
from safetensors.torch import load_file
|
59 |
+
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
60 |
+
repo_name = "tianweiy/DMD2"
|
61 |
+
ckpt_name = "dmd2_sdxl_1step_unet.bin"
|
62 |
+
# Load model.
|
63 |
+
unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to("cuda", torch.float16)
|
64 |
+
unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name), map_location="cuda"))
|
65 |
+
pipe = DiffusionPipeline.from_pretrained(base_model_id, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
|
66 |
+
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
67 |
+
prompt="a photo of a cat"
|
68 |
+
image=pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[399]).images[0]
|
69 |
+
```
|
70 |
+
|
71 |
+
For more information, please refer to the [code repository](https://github.com/tianweiy/DMD2)
|
72 |
|
73 |
|
74 |
## License
|