sohojoe committed on
Commit
3adda7c
1 Parent(s): 7d22699

use 16bit only if cuda

Files changed (1)
  1. app.py +3 -3
app.py CHANGED
@@ -94,7 +94,7 @@ def main(
     generator = torch.Generator().manual_seed(int(seed)) # use cpu as does not work on mps
 
     embeddings = base64_to_embedding(embeddings)
-    embeddings = torch.tensor(embeddings).to(device)
+    embeddings = torch.tensor(embeddings, dtype=torch_size).to(device)
 
     images_list = pipe(
         # inp.tile(n_samples, 1, 1, 1),
@@ -178,12 +178,12 @@ def on_embeddings_changed_update_plot(embeddings_b64):
     # height=300,
     width=embeddings.shape[0])
 
-
 device = torch.device("mps" if torch.backends.mps.is_available() else "cuda:0" if torch.cuda.is_available() else "cpu")
+torch_size = torch.float16 if device == ('cuda') else torch.float32
 pipe = StableDiffusionPipeline.from_pretrained(
     model_id,
     custom_pipeline="pipeline.py",
-    torch_dtype=torch.float16,
+    torch_dtype=torch_size,
     # , revision="fp16",
     requires_safety_checker = False, safety_checker=None,
     text_encoder = CLIPTextModel,
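
A minimal sketch (not part of the commit) of the dtype selection the commit message describes. One caveat: the diff's check device == ('cuda') compares a torch.device against a plain string, and with the device set to "cuda:0" that comparison is False, so comparing device.type is a more reliable way to enable float16 only on CUDA. The embedding tensor below is a hypothetical placeholder standing in for the base64-decoded embeddings in app.py.

import torch

# Pick the device the same way app.py does.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda:0" if torch.cuda.is_available()
    else "cpu"
)

# Use 16-bit weights only when running on CUDA (the intent of this commit).
# device.type is "cuda" even when the device is "cuda:0", so this check
# holds where a string comparison against the full device string would not.
torch_size = torch.float16 if device.type == "cuda" else torch.float32

# Input tensors then need the same dtype as the pipeline weights.
embeddings = torch.zeros(1, 77, 768)  # placeholder; app.py decodes these from base64
embeddings = embeddings.to(device=device, dtype=torch_size)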