Robert001 commited on
Commit
316bfe4
1 Parent(s): 044ef6e

first commit

Browse files
Files changed (3) hide show
  1. lib/autoencoder.py +1 -1
  2. lib/encoder.py +1 -1
  3. lib/utils.py +0 -117
lib/autoencoder.py CHANGED
@@ -16,7 +16,7 @@ from contextlib import contextmanager
16
  from lib.model import Encoder, Decoder
17
  from lib.distributions import DiagonalGaussianDistribution
18
 
19
- from lib.util import instantiate_from_config
20
 
21
  class AutoencoderKL(pl.LightningModule):
22
  def __init__(self,
 
16
  from lib.model import Encoder, Decoder
17
  from lib.distributions import DiagonalGaussianDistribution
18
 
19
+ from utils import instantiate_from_config
20
 
21
  class AutoencoderKL(pl.LightningModule):
22
  def __init__(self,
lib/encoder.py CHANGED
@@ -15,7 +15,7 @@ from torch.utils.checkpoint import checkpoint
15
  from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
16
 
17
  import open_clip
18
- from lib.util import default, count_params
19
 
20
 
21
  class AbstractEncoder(nn.Module):
 
15
  from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
16
 
17
  import open_clip
18
+ from utils import default, count_params
19
 
20
 
21
  class AbstractEncoder(nn.Module):
lib/utils.py DELETED
@@ -1,117 +0,0 @@
1
- '''
2
- * Copyright (c) 2023 Salesforce, Inc.
3
- * All rights reserved.
4
- * SPDX-License-Identifier: Apache License 2.0
5
- * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
6
- * By Can Qin
7
- * Modified from ControlNet repo: https://github.com/lllyasviel/ControlNet
8
- * Copyright (c) 2023 Lvmin Zhang and Maneesh Agrawala
9
- '''
10
-
11
- import os
12
- import torch
13
- from omegaconf import OmegaConf
14
- import importlib
15
- import numpy as np
16
-
17
-
18
- from inspect import isfunction
19
- from PIL import Image, ImageDraw, ImageFont
20
-
21
-
22
- def log_txt_as_img(wh, xc, size=10):
23
- # wh a tuple of (width, height)
24
- # xc a list of captions to plot
25
- b = len(xc)
26
- txts = list()
27
- for bi in range(b):
28
- txt = Image.new("RGB", wh, color="white")
29
- draw = ImageDraw.Draw(txt)
30
- font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
31
- nc = int(40 * (wh[0] / 256))
32
- lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
33
-
34
- try:
35
- draw.text((0, 0), lines, fill="black", font=font)
36
- except UnicodeEncodeError:
37
- print("Cant encode string for logging. Skipping.")
38
-
39
- txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
40
- txts.append(txt)
41
- txts = np.stack(txts)
42
- txts = torch.tensor(txts)
43
- return txts
44
-
45
-
46
- def ismap(x):
47
- if not isinstance(x, torch.Tensor):
48
- return False
49
- return (len(x.shape) == 4) and (x.shape[1] > 3)
50
-
51
-
52
- def isimage(x):
53
- if not isinstance(x,torch.Tensor):
54
- return False
55
- return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
56
-
57
-
58
- def exists(x):
59
- return x is not None
60
-
61
-
62
- def default(val, d):
63
- if exists(val):
64
- return val
65
- return d() if isfunction(d) else d
66
-
67
-
68
- def mean_flat(tensor):
69
- """
70
- https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
71
- Take the mean over all non-batch dimensions.
72
- """
73
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
74
-
75
- def count_params(model, verbose=False):
76
- total_params = sum(p.numel() for p in model.parameters())
77
- if verbose:
78
- print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
79
- return total_params
80
-
81
-
82
- def get_state_dict(d):
83
- return d.get('state_dict', d)
84
-
85
-
86
- def load_state_dict(ckpt_path, location='cpu'):
87
- _, extension = os.path.splitext(ckpt_path)
88
- if extension.lower() == ".safetensors":
89
- import safetensors.torch
90
- state_dict = safetensors.torch.load_file(ckpt_path, device=location)
91
- else:
92
- state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
93
- state_dict = get_state_dict(state_dict)
94
- print(f'Loaded state_dict from [{ckpt_path}]')
95
- return state_dict
96
-
97
- def get_obj_from_str(string, reload=False):
98
- module, cls = string.rsplit(".", 1)
99
- if reload:
100
- module_imp = importlib.import_module(module)
101
- importlib.reload(module_imp)
102
- return getattr(importlib.import_module(module, package=None), cls)
103
-
104
- def instantiate_from_config(config):
105
- if not "target" in config:
106
- if config == '__is_first_stage__':
107
- return None
108
- elif config == "__is_unconditional__":
109
- return None
110
- raise KeyError("Expected key `target` to instantiate.")
111
- return get_obj_from_str(config["target"])(**config.get("params", dict()))
112
-
113
- def create_model(config_path):
114
- config = OmegaConf.load(config_path)
115
- model = instantiate_from_config(config.model).cpu()
116
- print(f'Loaded model config from [{config_path}]')
117
- return model