zwv9 committed on
Commit 877e614
1 Parent(s): 5925c1a
Files changed (4)
  1. .gitignore +1 -0
  2. pipeline.py +11 -0
  3. requirements.txt +10 -0
  4. wfx.py +118 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ env/
pipeline.py ADDED
@@ -0,0 +1,11 @@
+ import re, logging, torch
+ import numpy as np
+ from typing import List, Optional, Union
+ from diffusers import DiffusionPipeline
+ from diffusers.utils import PIL_INTERPOLATION
+
+ # ------------------------- #
+
+ logger = logging.getLogger(__name__)
+
+ # ---------- LPW ---------- #
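
Only the imports and the module logger land in this commit; going by the `LPW` marker and the `re` / `PIL_INTERPOLATION` imports, pipeline.py presumably stubs out a long-prompt-weighting custom pipeline to be filled in later.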
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ asyncio==3.4.3
+ stable-fast==1.0.1
+ torch==2.1.2
+ torchvision==0.16.2
+ triton==2.1.0
+ xformers==0.0.23.post1
+ packaging==23.2
+ diffusers==0.25.1
+ peft==0.7.1
+ k-diffusion==0.1.1.post1
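
One note on the pins: `asyncio==3.4.3` is the old PyPI backport of the standard-library module and is redundant on Python 3, while the rest appear to form a mutually compatible stable-fast 1.0.1 / torch 2.1.2 stack (xformers 0.0.23.post1 and triton 2.1.0 are the builds matching torch 2.1.x). Installing is the usual `pip install -r requirements.txt`.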
wfx.py ADDED
@@ -0,0 +1,118 @@
+ import argparse, torch, logging
+ import packaging.version as pv
+ from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
+ from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConfig
+
+
+ logging.basicConfig(level=logging.INFO, format='%(name)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger('wfx')
+
+
+ # TF32 matmuls require torch >= 1.12
+ if pv.parse(torch.__version__) >= pv.parse('1.12.0'):
+     torch.backends.cuda.matmul.allow_tf32 = True
+     torch.backends.cudnn.allow_tf32 = True  # TF32 for cuDNN convolutions as well
+     logger.info('matching torch version, enabling tf32')
+
+
+ def parse_args():
+     args = argparse.ArgumentParser()
+     args.add_argument('--disable-xformers', action='store_true', default=False)
+     args.add_argument('--disable-triton', action='store_true', default=False)
+     args.add_argument('--quantize-unet', action='store_true', default=False)
+     args.add_argument('--model', type=str, required=True)
+     args.add_argument('--custom-pipeline', type=str, default=None)
+     return args.parse_args()
+
+
+ def quantize_unet(m):
+     from diffusers.utils import USE_PEFT_BACKEND
+     assert USE_PEFT_BACKEND
+
+     logger.info('PEFT backend detected, quantizing unet...')
+
+     # dynamic int8 quantization of the UNet's linear layers
+     m = torch.quantization.quantize_dynamic(
+         m, {torch.nn.Linear},
+         dtype=torch.qint8,
+         inplace=True,
+     )
+
+     logger.info('unet successfully quantized')
+     return m
+
+
+ class WFX:
+     compiler_config: CompilationConfig.Default = CompilationConfig.Default()
+     T2IPipeline: AutoPipelineForText2Image = None
+     I2IPipeline: AutoPipelineForImage2Image = None
+
+     def __init__(self) -> None:
+         args = parse_args()
+         self._check_optimization(args)
+
+     def _check_optimization(self, args) -> None:
+         logger.info(f'torch version: {torch.__version__}')
+
+         if not args.disable_xformers:
+             try:
+                 import xformers
+                 self.compiler_config.enable_xformers = True
+                 logger.info(f'xformers version: {xformers.__version__}')
+             except ImportError:
+                 logger.warning('xformers not found, disabling xformers')
+
+         if not args.disable_triton:
+             try:
+                 import triton
+                 self.compiler_config.enable_triton = True
+                 logger.info(f'triton version: {triton.__version__}')
+             except ImportError:
+                 logger.warning('triton not found, disabling triton')
+
+         self.compiler_config.enable_cuda_graph = True
+
+         for key in self.compiler_config.__dict__:
+             logger.info(f'cc - {key}: {self.compiler_config.__dict__[key]}')
+
+     def load(self) -> None:
+         args = parse_args()
+         extra_kwargs = {
+             'torch_dtype': torch.float16,
+             'use_safetensors': True,
+             'requires_safety_checker': False,
+         }
+
+         if args.custom_pipeline is not None:
+             logger.info(f'loading custom pipeline from "{args.custom_pipeline}"')
+             extra_kwargs['custom_pipeline'] = args.custom_pipeline
+
+         self.T2IPipeline = AutoPipelineForText2Image.from_pretrained(args.model, **extra_kwargs)
+         self.T2IPipeline.safety_checker = None
+         # self.T2IPipeline.to(torch.device('cuda:0'))  # device move deferred until after compile
+
+         if args.quantize_unet:
+             self.T2IPipeline.unet = quantize_unet(self.T2IPipeline.unet)
+
+         logger.info('compiling model...')
+         self.T2IPipeline = compile(self.T2IPipeline, self.compiler_config)
+
+         self.T2IPipeline.to(torch.device('cuda:0'))
+         self.warmup()
+
+     def warmup(self) -> None:
+         warmed = 5  # a few passes let stable-fast trace and capture CUDA graphs
+         warmup_kwargs = dict(
+             prompt='a photo of a cat',
+             height=768,
+             width=512,
+             num_inference_steps=30,
+             generator=torch.Generator(device='cuda:0').manual_seed(0),
+         )
+
+         logger.info(f'warming up T2I pipeline for {warmed} steps')
+         while warmed > 0:
+             self.T2IPipeline(**warmup_kwargs)
+             warmed -= 1
+
+
+ if __name__ == '__main__':
+     wfx = WFX()
+     wfx.load()
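
Not part of the commit, but a sketch of how the `__main__` block could be extended once loading works; the prompt and output filename below are placeholders, and the call assumes a standard diffusers text-to-image pipeline output:

# Hypothetical extension of wfx.py's __main__ block; prompt and filename are placeholders.
if __name__ == '__main__':
    wfx = WFX()   # parses CLI args; --model is required
    wfx.load()    # load, compile with stable-fast, move to cuda:0, warm up

    # the compiled pipeline is still called like a regular diffusers pipeline
    result = wfx.T2IPipeline(
        prompt='a photo of a cat',
        height=768,
        width=512,
        num_inference_steps=30,
    )
    result.images[0].save('out.png')

Reusing the warmup resolution (768x512) should let the first real call hit the already-captured CUDA graph rather than triggering a fresh capture.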