Robert001 commited on
Commit
e598301
1 Parent(s): 2193df9

first commit

Browse files
Files changed (4) hide show
  1. lib/attention.py +2 -2
  2. lib/ddpm_multi.py +2 -2
  3. lib/openaimodel.py +2 -2
  4. lib/util.py +2 -2
lib/attention.py CHANGED
@@ -16,7 +16,7 @@ from torch import nn, einsum
16
  from einops import rearrange, repeat
17
  from typing import Optional, Any
18
 
19
- from utils import checkpoint
20
 
21
  try:
22
  import xformers
@@ -351,4 +351,4 @@ class SpatialTransformer(nn.Module):
351
  x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
352
  if not self.use_linear:
353
  x = self.proj_out(x)
354
- return x + x_in
 
16
  from einops import rearrange, repeat
17
  from typing import Optional, Any
18
 
19
+ from ..utils import checkpoint
20
 
21
  try:
22
  import xformers
 
351
  x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
352
  if not self.use_linear:
353
  x = self.proj_out(x)
354
+ return x + x_in
lib/ddpm_multi.py CHANGED
@@ -30,7 +30,7 @@ from torchvision.utils import make_grid
30
  from pytorch_lightning.utilities.distributed import rank_zero_only
31
  from omegaconf import ListConfig
32
 
33
- from utils import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
34
  from lib.distributions import normal_kl, DiagonalGaussianDistribution
35
  from lib.autoencoder import IdentityFirstStage, AutoencoderKL
36
  from lib.util import make_beta_schedule, extract_into_tensor, noise_like
@@ -1798,4 +1798,4 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
1798
  def log_images(self, *args, **kwargs):
1799
  log = super().log_images(*args, **kwargs)
1800
  log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
1801
- return log
 
30
  from pytorch_lightning.utilities.distributed import rank_zero_only
31
  from omegaconf import ListConfig
32
 
33
+ from ..utils import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
34
  from lib.distributions import normal_kl, DiagonalGaussianDistribution
35
  from lib.autoencoder import IdentityFirstStage, AutoencoderKL
36
  from lib.util import make_beta_schedule, extract_into_tensor, noise_like
 
1798
  def log_images(self, *args, **kwargs):
1799
  log = super().log_images(*args, **kwargs)
1800
  log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
1801
+ return log
lib/openaimodel.py CHANGED
@@ -26,7 +26,7 @@ from lib.util import (
26
  timestep_embedding,
27
  )
28
  from attention import SpatialTransformer
29
- from utils import exists
30
 
31
 
32
  # dummy replace
@@ -793,4 +793,4 @@ class UNetModel(nn.Module):
793
  if self.predict_codebook_ids:
794
  return self.id_predictor(h)
795
  else:
796
- return self.out(h)
 
26
  timestep_embedding,
27
  )
28
  from attention import SpatialTransformer
29
+ from ..utils import exists
30
 
31
 
32
  # dummy replace
 
793
  if self.predict_codebook_ids:
794
  return self.id_predictor(h)
795
  else:
796
+ return self.out(h)
lib/util.py CHANGED
@@ -25,7 +25,7 @@ import torch.nn as nn
25
  import numpy as np
26
  from einops import repeat
27
 
28
- from utils import instantiate_from_config
29
 
30
 
31
  def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
@@ -277,4 +277,4 @@ class HybridConditioner(nn.Module):
277
  def noise_like(shape, device, repeat=False):
278
  repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
279
  noise = lambda: torch.randn(shape, device=device)
280
- return repeat_noise() if repeat else noise()
 
25
  import numpy as np
26
  from einops import repeat
27
 
28
+ from ..utils import instantiate_from_config
29
 
30
 
31
  def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
 
277
  def noise_like(shape, device, repeat=False):
278
  repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
279
  noise = lambda: torch.randn(shape, device=device)
280
+ return repeat_noise() if repeat else noise()