zwv9 commited on
Commit
d6b300f
1 Parent(s): 750d549
Files changed (1) hide show
  1. wfx.py +9 -2
wfx.py CHANGED
@@ -21,6 +21,7 @@ def parse_args():
21
  args.add_argument('--quantize-unet', action='store_true', default=False)
22
  args.add_argument('--model', type=str, required=True)
23
  args.add_argument('--custom-pipeline', type=str, default=None)
 
24
  return args.parse_args()
25
 
26
  def quantize_unet(m):
@@ -91,8 +92,14 @@ class WFX():
91
  if args.quantize_unet:
92
  self.T2IPipeline.unet = quantize_unet(self.T2IPipeline.unet)
93
 
94
- logger.info('compiling model...')
95
- self.T2IPipeline = compile(self.T2IPipeline, self.compiler_config)
 
 
 
 
 
 
96
 
97
  self.warmup()
98
 
 
21
  args.add_argument('--quantize-unet', action='store_true', default=False)
22
  args.add_argument('--model', type=str, required=True)
23
  args.add_argument('--custom-pipeline', type=str, default=None)
24
+ args.add_argument('--compile-mode', default='sfast', type=str, choices=['sfast', 'torch', 'no-compile'])
25
  return args.parse_args()
26
 
27
  def quantize_unet(m):
 
92
  if args.quantize_unet:
93
  self.T2IPipeline.unet = quantize_unet(self.T2IPipeline.unet)
94
 
95
+ logger.info(f'compiling pipeline in {args.compile_mode} mode...')
96
+ if args.compile_mode == 'sfast':
97
+ self.T2IPipeline = compile(self.T2IPipeline, self.compiler_config)
98
+ elif args.compile_mode == 'torch':
99
+ logger.info('compiling unet...')
100
+ self.T2IPipeline.unet = torch.compile(self.T2IPipeline.unet, mode='max-autotune')
101
+ logger.info('compiling vae...')
102
+ self.T2IPipeline.vae = torch.compile(self.T2IPipeline.vae, mode='max-autotune')
103
 
104
  self.warmup()
105