cocktailpeanut committed
Commit e9f320b
1 Parent(s): 9f53447
Files changed (1)
  1. app.py +16 -7
app.py CHANGED
@@ -8,7 +8,7 @@ from typing import List
 from diffusers.utils import numpy_to_pil
 from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
 from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
-import spaces
+#import spaces
 from previewer.modules import Previewer
 import user_history
 
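Note: `spaces` is Hugging Face's ZeroGPU helper, so commenting the import out removes the Spaces-only dependency and lets app.py start on a local machine. A minimal sketch of a guarded import that would serve both environments; the fallback class here is our illustration, not part of the commit:

    try:
        import spaces  # ZeroGPU helper, only meaningful on Spaces hardware
    except ImportError:
        class spaces:  # local stand-in so @spaces.GPU becomes a no-op
            @staticmethod
            def GPU(fn):
                return fn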
@@ -16,19 +16,28 @@ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 
 DESCRIPTION = "# Stable Cascade"
 DESCRIPTION += "\n<p style=\"text-align: center\">Unofficial demo for <a href='https://huggingface.co/stabilityai/stable-cascade' target='_blank'>Stable Casacade</a>, a new high resolution text-to-image model by Stability AI, built on the Würstchen architecture - <a href='https://huggingface.co/stabilityai/stable-cascade/blob/main/LICENSE' target='_blank'>non-commercial research license</a></p>"
-if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>Running on CPU 🥶</p>"
+#if not torch.cuda.is_available():
+#    DESCRIPTION += "\n<p>Running on CPU 🥶</p>"
 
 MAX_SEED = np.iinfo(np.int32).max
-CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") != "0"
+CACHE_EXAMPLES = False
+#CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") != "0"
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
 USE_TORCH_COMPILE = False
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
 PREVIEW_IMAGES = True
 
 dtype = torch.bfloat16
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 if torch.cuda.is_available():
+    device = "cuda"
+elif torch.backends.mps.is_available():
+    device = "mps"
+    dtype = torch.float32
+else:
+    device = "cpu"
+#if torch.cuda.is_available():
+if device != "cpu":
     prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype)#.to(device)
     decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)#.to(device)
 
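Note: the new block prefers CUDA, then Apple's Metal backend (MPS), and falls back to CPU; on MPS the dtype drops from bfloat16 to float32, since bfloat16 has historically been poorly supported there, and `CACHE_EXAMPLES` is pinned to False so startup no longer pre-generates the example images. The same selection logic, factored into a helper for clarity (the function name is ours, for illustration only):

    import torch

    def pick_device_and_dtype():
        # prefer CUDA, then MPS on Apple Silicon, otherwise stay on CPU
        dtype = torch.bfloat16
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
            dtype = torch.float32  # bfloat16 is unreliable on MPS
        else:
            device = "cpu"
        return device, dtype

    device, dtype = pick_device_and_dtype()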
@@ -66,7 +75,7 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
         seed = random.randint(0, MAX_SEED)
     return seed
 
-@spaces.GPU
+#@spaces.GPU
 def generate(
     prompt: str,
     negative_prompt: str = "",
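Note: `@spaces.GPU` asks ZeroGPU to attach a GPU for the duration of each generate() call, so once the import is gone the decorator must be commented out as well. With the guarded import sketched under the first hunk, the line could instead stay in place:

    # assuming the guarded import above: `spaces` is either the real module
    # or the no-op fallback, so this decorator works in both environments
    @spaces.GPU
    def generate(prompt: str, negative_prompt: str = "", **kwargs):
        ...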
@@ -276,4 +285,4 @@ with gr.Blocks(css="style.css") as demo_with_history:
     user_history.render()
 
 if __name__ == "__main__":
-    demo_with_history.queue(max_size=20).launch()
+    demo_with_history.queue(max_size=20).launch()