ysharma HF staff commited on
Commit
d0e6699
1 Parent(s): c3e5ed7
Files changed (1) hide show
  1. util.py +86 -0
util.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+ import torch
4
+ import numpy as np
5
+
6
+ from inspect import isfunction
7
+ from PIL import Image, ImageDraw, ImageFont
8
+
9
+
10
+ def log_txt_as_img(wh, xc, size=10):
11
+ # wh a tuple of (width, height)
12
+ # xc a list of captions to plot
13
+ b = len(xc)
14
+ txts = list()
15
+ for bi in range(b):
16
+ txt = Image.new("RGB", wh, color="white")
17
+ draw = ImageDraw.Draw(txt)
18
+ font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
19
+ nc = int(40 * (wh[0] / 256))
20
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
21
+
22
+ try:
23
+ draw.text((0, 0), lines, fill="black", font=font)
24
+ except UnicodeEncodeError:
25
+ print("Cant encode string for logging. Skipping.")
26
+
27
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
28
+ txts.append(txt)
29
+ txts = np.stack(txts)
30
+ txts = torch.tensor(txts)
31
+ return txts
32
+
33
+
34
+ def ismap(x):
35
+ if not isinstance(x, torch.Tensor):
36
+ return False
37
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
38
+
39
+
40
+ def isimage(x):
41
+ if not isinstance(x,torch.Tensor):
42
+ return False
43
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
44
+
45
+
46
+ def exists(x):
47
+ return x is not None
48
+
49
+
50
+ def default(val, d):
51
+ if exists(val):
52
+ return val
53
+ return d() if isfunction(d) else d
54
+
55
+
56
+ def mean_flat(tensor):
57
+ """
58
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
59
+ Take the mean over all non-batch dimensions.
60
+ """
61
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
62
+
63
+
64
+ def count_params(model, verbose=False):
65
+ total_params = sum(p.numel() for p in model.parameters())
66
+ if verbose:
67
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
68
+ return total_params
69
+
70
+
71
+ def instantiate_from_config(config):
72
+ if not "target" in config:
73
+ if config == '__is_first_stage__':
74
+ return None
75
+ elif config == "__is_unconditional__":
76
+ return None
77
+ raise KeyError("Expected key `target` to instantiate.")
78
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
79
+
80
+
81
+ def get_obj_from_str(string, reload=False):
82
+ module, cls = string.rsplit(".", 1)
83
+ if reload:
84
+ module_imp = importlib.import_module(module)
85
+ importlib.reload(module_imp)
86
+ return getattr(importlib.import_module(module, package=None), cls)