lazerkat committed on
Commit
4ad9a53
·
verified ·
1 Parent(s): 3192df2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -5
app.py CHANGED
@@ -146,14 +146,14 @@ class Diffusion:
146
  self.alpha_bars = torch.cumprod(self.alphas, dim=0)
147
 
148
  @torch.no_grad()
149
- def sample(self, model, text_tokens, image_size=64, steps=None):
150
  model.eval()
151
  if steps is None:
152
  steps = self.timesteps
153
 
154
  x = torch.randn(1, 3, image_size, image_size).to(self.device)
155
 
156
- for t in reversed(range(steps)):
157
  t_batch = torch.full((x.shape[0],), t, device=self.device, dtype=torch.long)
158
  predicted_noise = model(x, t_batch, text_tokens)
159
 
@@ -169,6 +169,11 @@ class Diffusion:
169
  x = (1 / torch.sqrt(alpha)) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * predicted_noise)
170
  x = x + torch.sqrt(beta) * noise
171
 
 
 
 
 
 
172
  model.train()
173
  return x
174
 
@@ -232,18 +237,31 @@ def tokenize_text(text, max_len=20):
232
  indices.append(0) # PAD token
233
  return torch.tensor(indices).unsqueeze(0).to(device)
234
 
235
- # Generate image
236
- def generate_image(prompt):
237
  global model, device, vocab_data
238
 
239
  if model is None or vocab_data is None:
240
  return None
241
 
 
 
242
  diffusion = Diffusion(timesteps=500, device=device) # Use 500 timesteps like training
243
 
 
 
 
244
  with torch.no_grad():
245
  text_tokens = tokenize_text(prompt)
246
- generated = diffusion.sample(model, text_tokens, image_size=64, steps=500)
 
 
 
 
 
 
 
 
247
 
248
  # Convert to image
249
  image = generated.cpu().squeeze(0)
 
146
  self.alpha_bars = torch.cumprod(self.alphas, dim=0)
147
 
148
  @torch.no_grad()
149
+ def sample(self, model, text_tokens, image_size=64, steps=None, progress_callback=None):
150
  model.eval()
151
  if steps is None:
152
  steps = self.timesteps
153
 
154
  x = torch.randn(1, 3, image_size, image_size).to(self.device)
155
 
156
+ for i, t in enumerate(reversed(range(steps))):
157
  t_batch = torch.full((x.shape[0],), t, device=self.device, dtype=torch.long)
158
  predicted_noise = model(x, t_batch, text_tokens)
159
 
 
169
  x = (1 / torch.sqrt(alpha)) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * predicted_noise)
170
  x = x + torch.sqrt(beta) * noise
171
 
172
+ # Report progress
173
+ if progress_callback is not None:
174
+ progress = (i + 1) / steps
175
+ progress_callback(progress)
176
+
177
  model.train()
178
  return x
179
 
 
237
  indices.append(0) # PAD token
238
  return torch.tensor(indices).unsqueeze(0).to(device)
239
 
240
+ # Generate image with progress
241
+ def generate_image(prompt, progress=gr.Progress()):
242
  global model, device, vocab_data
243
 
244
  if model is None or vocab_data is None:
245
  return None
246
 
247
+ progress(0, desc="Starting generation...")
248
+
249
  diffusion = Diffusion(timesteps=500, device=device) # Use 500 timesteps like training
250
 
251
+ def update_progress(pct):
252
+ progress(pct, desc=f"Generating... {pct*100:.1f}%")
253
+
254
  with torch.no_grad():
255
  text_tokens = tokenize_text(prompt)
256
+ generated = diffusion.sample(
257
+ model,
258
+ text_tokens,
259
+ image_size=64,
260
+ steps=500,
261
+ progress_callback=update_progress
262
+ )
263
+
264
+ progress(1.0, desc="Converting to image...")
265
 
266
  # Convert to image
267
  image = generated.cpu().squeeze(0)