hysts HF staff commited on
Commit
f5e83aa
1 Parent(s): 93fd2ea
Files changed (1) hide show
  1. model.py +5 -11
model.py CHANGED
@@ -11,17 +11,11 @@ class Model:
11
  self.device = torch.device(
12
  'cuda:0' if torch.cuda.is_available() else 'cpu')
13
  model_id = 'CompVis/stable-diffusion-v1-4'
14
- if self.device.type == 'cuda':
15
- self.ax_pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
16
- model_id, torch_dtype=torch.float16)
17
- self.ax_pipe.to(self.device)
18
- self.sd_pipe = StableDiffusionPipeline.from_pretrained(
19
- model_id, torch_dtype=torch.float16)
20
- self.sd_pipe.to(self.device)
21
- else:
22
- self.ax_pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
23
- model_id)
24
- self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)
25
 
26
  def get_token_table(self, prompt: str):
27
  tokens = [
 
11
  self.device = torch.device(
12
  'cuda:0' if torch.cuda.is_available() else 'cpu')
13
  model_id = 'CompVis/stable-diffusion-v1-4'
14
+ self.ax_pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
15
+ model_id)
16
+ self.ax_pipe.to(self.device)
17
+ self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)
18
+ self.sd_pipe.to(self.device)
 
 
 
 
 
 
19
 
20
  def get_token_table(self, prompt: str):
21
  tokens = [