hysts HF staff commited on
Commit
7aa55b4
1 Parent(s): 98d1c03

Update for ZeroGPU

Browse files
Files changed (2) hide show
  1. app.py +2 -1
  2. requirements.txt +1 -0
app.py CHANGED
@@ -8,6 +8,7 @@ import random
8
  import gradio as gr
9
  import numpy as np
10
  import PIL.Image
 
11
  import torch
12
  from diffusers import StableDiffusionAttendAndExcitePipeline, StableDiffusionPipeline
13
 
@@ -47,6 +48,7 @@ def get_token_table(prompt: str) -> list[tuple[int, str]]:
47
  return list(enumerate(tokens, start=1))
48
 
49
 
 
50
  def run(
51
  prompt: str,
52
  indices_to_alter_str: str,
@@ -60,7 +62,6 @@ def run(
60
  20: 0.8,
61
  },
62
  max_iter_to_alter: int = 25,
63
- progress=gr.Progress(track_tqdm=True),
64
  ) -> PIL.Image.Image:
65
  if num_inference_steps > MAX_INFERENCE_STEPS:
66
  raise gr.Error(f"Number of steps cannot exceed {MAX_INFERENCE_STEPS}.")
 
8
  import gradio as gr
9
  import numpy as np
10
  import PIL.Image
11
+ import spaces
12
  import torch
13
  from diffusers import StableDiffusionAttendAndExcitePipeline, StableDiffusionPipeline
14
 
 
48
  return list(enumerate(tokens, start=1))
49
 
50
 
51
+ @spaces.GPU
52
  def run(
53
  prompt: str,
54
  indices_to_alter_str: str,
 
62
  20: 0.8,
63
  },
64
  max_iter_to_alter: int = 25,
 
65
  ) -> PIL.Image.Image:
66
  if num_inference_steps > MAX_INFERENCE_STEPS:
67
  raise gr.Error(f"Number of steps cannot exceed {MAX_INFERENCE_STEPS}.")
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  accelerate==0.23.0
2
  diffusers==0.21.2
3
  Pillow==10.0.1
 
4
  torch==2.0.0
5
  torchvision==0.15.1
6
  transformers==4.33.2
 
1
  accelerate==0.23.0
2
  diffusers==0.21.2
3
  Pillow==10.0.1
4
+ spaces==0.14.0
5
  torch==2.0.0
6
  torchvision==0.15.1
7
  transformers==4.33.2