zwv9 commited on
Commit
8c4b798
1 Parent(s): 753e002

forward to gpu before quant

Browse files
Files changed (1) hide show
  1. wfx.py +2 -3
wfx.py CHANGED
@@ -86,7 +86,7 @@ class WFX():
86
 
87
  self.T2IPipeline = AutoPipelineForText2Image.from_pretrained(args.model, **extra_kwargs)
88
  self.T2IPipeline.safety_checker = None
89
- # self.T2IPipeline.to(torch.device('cuda:0'))
90
 
91
  if args.quantize_unet:
92
  self.T2IPipeline.unet = quantize_unet(self.T2IPipeline.unet)
@@ -94,7 +94,6 @@ class WFX():
94
  logger.info('compiling model...')
95
  self.T2IPipeline = compile(self.T2IPipeline, self.compiler_config)
96
 
97
- self.T2IPipeline.to(torch.device('cuda:0'))
98
  self.warmup()
99
 
100
  def warmup(self) -> None:
@@ -103,7 +102,7 @@ class WFX():
103
 
104
  warmup_kwargs = dict(
105
  prompt='a photo of a cat',
106
- height=768,
107
  width=512,
108
  num_inference_steps=30,
109
  generator=torch.Generator(device='cuda:0').manual_seed(0),
 
86
 
87
  self.T2IPipeline = AutoPipelineForText2Image.from_pretrained(args.model, **extra_kwargs)
88
  self.T2IPipeline.safety_checker = None
89
+ self.T2IPipeline.to(torch.device('cuda:0'))
90
 
91
  if args.quantize_unet:
92
  self.T2IPipeline.unet = quantize_unet(self.T2IPipeline.unet)
 
94
  logger.info('compiling model...')
95
  self.T2IPipeline = compile(self.T2IPipeline, self.compiler_config)
96
 
 
97
  self.warmup()
98
 
99
  def warmup(self) -> None:
 
102
 
103
  warmup_kwargs = dict(
104
  prompt='a photo of a cat',
105
+ height=512,
106
  width=512,
107
  num_inference_steps=30,
108
  generator=torch.Generator(device='cuda:0').manual_seed(0),