bluestarburst committed
Commit 00e8857
1 Parent(s): 9a6a590

Upload folder using huggingface_hub

Files changed (2):
  1. handler.py +8 -3
  2. train.py +14 -9
handler.py CHANGED
@@ -10,6 +10,7 @@ import os
 from diffusers.utils.import_utils import is_xformers_available
 from typing import Any
 import torch
+import imageio
 import torchvision
 import numpy as np
 from einops import rearrange
@@ -101,10 +102,14 @@ class EndpointHandler():
             x = (x * 255).numpy().astype(np.uint8)
             outputs.append(x)
 
-        # imageio.mimsave(path, outputs, fps=fps)
+        path = "output.gif"
+        imageio.mimsave(path, outputs, fps=fps)
 
-        # return a gif file as bytes
-        return outputs
+        # open the file as binary and read the data
+        with open(path, mode="rb") as file:
+            fileContent = file.read()
+        # return json response with binary data
+        return fileContent
 
 
 # This is the entry point for the serverless function.
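For context, the new return path renders the frames to a GIF with imageio and sends back the file's raw bytes instead of the frame list. A minimal sketch of how a caller might consume that response, assuming a deployed Inference Endpoint that returns the GIF bytes directly in the response body (the URL, token, and prompt below are placeholders, not part of this commit):

import requests

# Hypothetical deployment values -- replace with your own endpoint URL and token.
ENDPOINT_URL = "https://your-endpoint.endpoints.huggingface.cloud"
HF_TOKEN = "hf_..."

response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json={"inputs": "a corgi running on the beach"},
)
response.raise_for_status()

# The handler returns the GIF file's bytes, so they can be written straight to disk.
with open("result.gif", "wb") as f:
    f.write(response.content)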
train.py CHANGED
@@ -321,6 +321,7 @@ def main(
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
     progress_bar.set_description("Steps")
+    optimizer.zero_grad()
 
     for epoch in range(first_epoch, num_train_epochs):
         unet.train()
@@ -363,28 +364,32 @@ def main(
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.prediction_type}")
 
+
                 # Predict the noise residual and compute loss
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                print("Model Output:", model_pred)
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                 # Gather the losses across all processes for logging (if we use distributed training).
                 avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
                 train_loss += avg_loss.item() / gradient_accumulation_steps
 
-                for name, module in unet.named_modules():
-                    if "motion_modules" in name and (train_whole_module or name.endswith(tuple(trainable_modules))):
-                        for params in module.parameters():
-                            params.requires_grad = True
+                print("Loss:", loss)
 
                 # Backpropagate
-                accelerator.backward(loss)
+                # accelerator.backward(loss)
+
+                with accelerator.scaler.scale_loss(loss) as scaled_loss:
+                    scaled_loss.backward()
+
                 if accelerator.sync_gradients:
                     accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)
 
-                # for param in unet.parameters():
-                # print(param.grad)
-
-
+                print("grad: ")
+                for param in unet.parameters():
+                    if param.grad is not None:
+                        print(param.grad)
+                        break
 
                 optimizer.step()
                 lr_scheduler.step()
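A note on the backward swap above: accelerator.backward(loss) is Accelerate's documented entry point and already applies loss scaling internally under mixed precision, while the scale_loss context manager is NVIDIA Apex's amp API. torch.cuda.amp.GradScaler, which Accelerate wraps, exposes scale() rather than scale_loss(), so accelerator.scaler.scale_loss would be expected to raise an AttributeError unless a custom scaler is in use. A minimal sketch of the documented pattern, keeping this commit's zero_grad placement and gradient sanity check (unet, optimizer, lr_scheduler, and max_grad_norm assumed from the diff context):

# Standard Accelerate training step: backward, clip, step, zero.
accelerator.backward(loss)  # handles loss scaling itself under mixed precision

if accelerator.sync_gradients:
    accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)

# Sanity check: print the first non-None gradient to confirm backprop reached the model.
for param in unet.parameters():
    if param.grad is not None:
        print("grad:", param.grad)
        break

optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()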