CyranoB commited on
Commit
127a815
1 Parent(s): 6fcf655
Files changed (2) hide show
  1. app.py +20 -15
  2. requirements.txt +2 -4
app.py CHANGED
@@ -20,11 +20,11 @@ if not torch.cuda.is_available():
20
  DESCRIPTION += "\n<p>Running on CPU 🥶</p>"
21
 
22
  MAX_SEED = np.iinfo(np.int32).max
23
- CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") != "0"
24
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
25
  USE_TORCH_COMPILE = False
26
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
27
- PREVIEW_IMAGES = True
28
 
29
  dtype = torch.bfloat16
30
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -47,10 +47,12 @@ if torch.cuda.is_available():
47
  previewer = Previewer()
48
  previewer_state_dict = torch.load("previewer/previewer_v1_100k.pt", map_location=torch.device('cpu'))["state_dict"]
49
  previewer.load_state_dict(previewer_state_dict)
50
- def callback_prior(i, t, latents):
 
51
  output = previewer(latents)
52
  output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
53
- return output
 
54
  callback_steps = 1
55
  else:
56
  previewer = None
@@ -62,6 +64,7 @@ else:
62
 
63
 
64
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
 
65
  if randomize_seed:
66
  seed = random.randint(0, MAX_SEED)
67
  return seed
@@ -82,7 +85,8 @@ def generate(
82
  num_images_per_prompt: int = 2,
83
  profile: gr.OAuthProfile | None = None,
84
  ) -> PIL.Image.Image:
85
- previewer.eval().requires_grad_(False).to(device).to(dtype)
 
86
  prior_pipeline.to(device)
87
  decoder_pipeline.to(device)
88
 
@@ -98,10 +102,9 @@ def generate(
98
  guidance_scale=prior_guidance_scale,
99
  num_images_per_prompt=num_images_per_prompt,
100
  generator=generator,
101
- callback=callback_prior,
102
- callback_steps=callback_steps
103
  )
104
-
105
  if PREVIEW_IMAGES:
106
  for _ in range(len(DEFAULT_STAGE_C_TIMESTEPS)):
107
  r = next(prior_output)
@@ -119,7 +122,7 @@ def generate(
119
  generator=generator,
120
  output_type="pil",
121
  ).images
122
-
123
  #Save images
124
  for image in decoder_output:
125
  user_history.save_image(
@@ -137,15 +140,17 @@ def generate(
137
  "num_images_per_prompt": num_images_per_prompt,
138
  },
139
  )
140
-
141
  yield decoder_output[0]
142
 
143
 
144
  examples = [
145
- "An astronaut riding a green horse",
146
- "A mecha robot in a favela by Tarsila do Amaral",
147
- "The sprirt of a Tamagotchi wandering in the city of Los Angeles",
148
- "A delicious feijoada ramen dish"
 
 
149
  ]
150
 
151
  with gr.Blocks() as demo:
@@ -277,4 +282,4 @@ with gr.Blocks(css="style.css") as demo_with_history:
277
  user_history.render()
278
 
279
  if __name__ == "__main__":
280
- demo_with_history.queue(max_size=20).launch()
 
20
  DESCRIPTION += "\n<p>Running on CPU 🥶</p>"
21
 
22
  MAX_SEED = np.iinfo(np.int32).max
23
+ CACHE_EXAMPLES = False #torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") != "0"
24
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
25
  USE_TORCH_COMPILE = False
26
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
27
+ PREVIEW_IMAGES = False
28
 
29
  dtype = torch.bfloat16
30
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
47
  previewer = Previewer()
48
  previewer_state_dict = torch.load("previewer/previewer_v1_100k.pt", map_location=torch.device('cpu'))["state_dict"]
49
  previewer.load_state_dict(previewer_state_dict)
50
+ def callback_prior(pipeline, step_index, t, callback_kwargs):
51
+ latents = callback_kwargs["latents"]
52
  output = previewer(latents)
53
  output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
54
+ callback_kwargs["preview_output"] = output
55
+ return callback_kwargs
56
  callback_steps = 1
57
  else:
58
  previewer = None
 
64
 
65
 
66
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
67
+ print("randomizing seed")
68
  if randomize_seed:
69
  seed = random.randint(0, MAX_SEED)
70
  return seed
 
85
  num_images_per_prompt: int = 2,
86
  profile: gr.OAuthProfile | None = None,
87
  ) -> PIL.Image.Image:
88
+
89
+ #previewer.eval().requires_grad_(False).to(device).to(dtype)
90
  prior_pipeline.to(device)
91
  decoder_pipeline.to(device)
92
 
 
102
  guidance_scale=prior_guidance_scale,
103
  num_images_per_prompt=num_images_per_prompt,
104
  generator=generator,
105
+ #callback_on_step_end=callback_prior,
106
+ #callback_on_step_end_tensor_inputs=['latents']
107
  )
 
108
  if PREVIEW_IMAGES:
109
  for _ in range(len(DEFAULT_STAGE_C_TIMESTEPS)):
110
  r = next(prior_output)
 
122
  generator=generator,
123
  output_type="pil",
124
  ).images
125
+ print(decoder_output)
126
  #Save images
127
  for image in decoder_output:
128
  user_history.save_image(
 
140
  "num_images_per_prompt": num_images_per_prompt,
141
  },
142
  )
143
+
144
  yield decoder_output[0]
145
 
146
 
147
  examples = [
148
+ "A futuristic cityscape at sunset",
149
+ "pair of shoes made of dried fruit skins, 3d render, bright colours, clean composition, beautiful artwork, logo",
150
+ "post-apocalyptic wasteland, the most delicate beautiful flower with green leaves growing from dust and rubble, vibrant colours, cinematic",
151
+ "Mixed media artwork, Emotional cyborg girl, Elegant dress, Skin lesions as a storytelling element, In the style of surrealist expressionism, muted color scheme, dreamlike atmosphere, abstract and distorted forms",
152
+ "rendering, side shot, falf-strange body with complex system equipment with hyper detail robot, gaze, sci-fi, gloomy environment, foggy with light shader, cyan and yellow illuminations, dramatic lighting, RTX shader, hyper detail texture with reflection, HDRI, cyborg, grunge, bolt, UHD",
153
+ "vintage Japanese postcard, in the style of Kentaro Miura, featuring a black cat holding a vinyl record in its paws, with vintage colors including light beige and red tones on a white background, very detailed artwork."
154
  ]
155
 
156
  with gr.Blocks() as demo:
 
282
  user_history.render()
283
 
284
  if __name__ == "__main__":
285
+ demo_with_history.queue(max_size=20).launch()
requirements.txt CHANGED
@@ -1,5 +1,3 @@
1
- git+https://github.com/kashif/diffusers.git@diffusers-yield-callback
2
- https://gradio-builds.s3.amazonaws.com/aabb08191a7d94d2a1e9ff87b0d3c3987cd519c5/gradio-4.18.0-py3-none-any.whl
3
  accelerate
4
- safetensors
5
- transformers
 
1
+ diffusers
 
2
  accelerate
3
+ transformers