added a few ldm files
- ldm/analysis_utils.py +27 -0
- ldm/loading_utils.py +38 -0
- ldm/lr_scheduler.py +120 -0
- ldm/plotting_utils.py +200 -0
- ldm/util.py +243 -0
ldm/analysis_utils.py
ADDED
@@ -0,0 +1,27 @@
import torch

EPS = 1e-10

def get_CosineDistance_matrix(features):
    if features.dim() > 2:
        features = features.reshape(features.shape[0], -1)

    features_norm = features / (EPS + features.norm(dim=1)[:, None])
    ans = torch.mm(features_norm, features_norm.transpose(0, 1))

    # We want distance, not similarity.
    ans = torch.add(-ans, 1.)

    return ans

def aggregatefrom_specimen_to_species(sorted_class_names_according_to_class_indx, specimen_distance_matrix, z_size, channels):
    unique_sorted_class_names_according_to_class_indx = sorted(set(sorted_class_names_according_to_class_indx))

    # species_dist_matrix = torch.zeros(len(unique_sorted_class_names_according_to_class_indx), 256, 16, 16)
    species_dist_matrix = torch.zeros(len(unique_sorted_class_names_according_to_class_indx), channels, z_size, z_size)
    for indx_i, i in enumerate(unique_sorted_class_names_according_to_class_indx):
        class_i_indices = [idx for idx, element in enumerate(sorted_class_names_according_to_class_indx) if element == i]
        species_dist_matrix[indx_i] = torch.mean(specimen_distance_matrix[class_i_indices, :], dim=0, keepdim=True)

    return species_dist_matrix
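A minimal sketch of how get_CosineDistance_matrix might be called; the tensor shape and batch size below are hypothetical, not from this commit:

import torch
from ldm.analysis_utils import get_CosineDistance_matrix

# Four hypothetical specimen embeddings of shape (channels, z, z); anything
# with more than 2 dims is flattened per row before the cosine distance.
features = torch.randn(4, 8, 16, 16)
dist = get_CosineDistance_matrix(features)

print(dist.shape)       # torch.Size([4, 4])
print(dist.diagonal())  # ~0: each embedding has (near-)zero distance to itself

aggregatefrom_specimen_to_species then averages such per-specimen entries over all specimens that share a class name, producing one entry per species.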
ldm/loading_utils.py
ADDED
@@ -0,0 +1,38 @@
# based on https://github.com/CompVis/taming-transformers

import yaml
from omegaconf import OmegaConf
import torch
from ldm.util import instantiate_from_config

######### loaders

def load_config(config_path, display=False):
    config = OmegaConf.load(config_path)
    if display:
        print(yaml.dump(OmegaConf.to_container(config)))
    return config

def load_model_from_config(config, ckpt):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt)  # , map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    model.cuda()
    model.eval()
    return model

def load_model(config_path, ckpt_path=None):
    # def load_model(config_path, ckpt_path=None, cuda=False, model_type=VQModel):
    #     breakpoint()
    #     model = model_type(**config.model.params)
    #     if ckpt_path is not None:
    #         sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    #         missing, unexpected = model.load_state_dict(sd, strict=True)
    #     if cuda:
    #         model = model.cuda()

    config = OmegaConf.load(config_path)
    model = load_model_from_config(config, ckpt_path)
    return model
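A sketch of the intended call pattern; the config and checkpoint paths are placeholders, and a CUDA device is assumed since load_model_from_config calls model.cuda():

from ldm.loading_utils import load_config, load_model

# Hypothetical paths; substitute a real config/checkpoint pair.
config = load_config("configs/example.yaml", display=True)
model = load_model("configs/example.yaml", ckpt_path="logs/run/checkpoints/last.ckpt")

Note that load_state_dict is called with strict=False, so missing or unexpected keys in the checkpoint are silently tolerated.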
ldm/lr_scheduler.py
ADDED
@@ -0,0 +1,120 @@
import numpy as np


class LambdaWarmUpCosineScheduler:
    """
    note: use with a base_lr of 1.0
    """
    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
            self.last_lr = lr
            return lr
        else:
            t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
            t = min(t, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
            self.last_lr = lr
            return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.
    """
    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0, gamma=0.99, step_size=1000):
        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.gamma = gamma
        self.step_size = step_size
        self.cycle_lengths = cycle_lengths
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        interval = 0
        for cl in self.cum_cycles[1:]:
            if n <= cl:
                return interval
            interval += 1

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                      f"current cycle {cycle}")
        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
            self.last_f = f
            return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                      f"current cycle {cycle}")

        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
            self.last_f = f
            return f

class LambdaLinearScheduler_step(LambdaWarmUpCosineScheduler2):

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                      f"current cycle {cycle}")

        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            f = self.gamma ** ((n - self.lr_warm_up_steps[cycle]) // self.step_size)
            # f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
            self.last_f = f
            return f

# class LambdaCustomScheduler:
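These classes return learning-rate multipliers rather than learning rates, so (per the docstrings) they are meant to be wrapped in torch.optim.lr_scheduler.LambdaLR with a base lr of 1.0. A small illustrative sketch; the optimizer, warm-up length, and cycle length are made up:

import torch
from ldm.lr_scheduler import LambdaLinearScheduler

# One cycle: 100 warm-up steps out of 1000 total. The multiplier ramps
# linearly from f_start to f_max, then decays linearly toward f_min.
sched_fn = LambdaLinearScheduler(
    warm_up_steps=[100], f_min=[0.0], f_max=[1.0], f_start=[1e-6],
    cycle_lengths=[1000],
)

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)  # base_lr of 1.0
lr_sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=sched_fn)

for step in range(5):
    opt.step()
    lr_sched.step()
    print(step, opt.param_groups[0]["lr"])  # multiplier grows during warm-up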
ldm/plotting_utils.py
ADDED
@@ -0,0 +1,200 @@
# based on https://github.com/CompVis/taming-transformers

import matplotlib.pyplot as plt
import seaborn as sns
import os
from pathlib import Path
import torchvision
import torch
import numpy as np
from PIL import Image
import json
import csv
import pandas as pd

from sklearn.metrics import ConfusionMatrixDisplay


def dump_to_json(dict, ckpt_path, name='results', get_fig_path=True):

    if get_fig_path:
        root = get_fig_pth(ckpt_path)
    else:
        root = ckpt_path
    if not os.path.exists(root):
        os.mkdir(root)

    with open(os.path.join(root, name + ".json"), "w") as outfile:
        json.dump(dict, outfile)


def save_to_cvs(ckpt_path, postfix, file_name, list_of_created_sequence):
    if ckpt_path is not None:
        root = get_fig_pth(ckpt_path, postfix=postfix)
    else:
        root = postfix

    file = open(os.path.join(root, file_name), 'w')
    with file:
        write = csv.writer(file)
        write.writerows(list_of_created_sequence)

def save_to_txt(arr, ckpt_path, name='results'):
    root = get_fig_pth(ckpt_path)
    with open(os.path.join(root, name + ".txt"), "w") as outfile:
        outfile.write(str(arr))


def save_image_grid(torch_images, ckpt_path=None, subfolder=None, postfix="", nrow=10):
    if ckpt_path is not None:
        root = get_fig_pth(ckpt_path, postfix=subfolder)
    else:
        root = subfolder

    grid = torchvision.utils.make_grid(torch_images, nrow=nrow)
    grid = torch.clamp(grid, -1., 1.)

    grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
    grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
    grid = grid.cpu().numpy()
    grid = (grid * 255).astype(np.uint8)
    filename = "code_changes_" + postfix + ".png"
    path = os.path.join(root, filename)
    os.makedirs(os.path.split(path)[0], exist_ok=True)
    Image.fromarray(grid).save(path)


def unprocess_image(torch_image):
    torch_image = torch.clamp(torch_image, -1., 1.)

    torch_image = (torch_image + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
    torch_image = torch_image.transpose(0, 1).transpose(1, 2).squeeze(-1)
    torch_image = torch_image.cpu().numpy()
    torch_image = (torch_image * 255).astype(np.uint8)
    return torch_image

def save_image(torch_image, image_name, ckpt_path=None, subfolder=None):
    if ckpt_path is not None:
        root = get_fig_pth(ckpt_path, postfix=subfolder)
    else:
        root = subfolder

    torch_image = unprocess_image(torch_image)

    filename = image_name + ".png"
    path = os.path.join(root, filename)
    os.makedirs(os.path.split(path)[0], exist_ok=True)
    fig = plt.figure()
    plt.imshow(torch_image[0].squeeze())
    fig.savefig(path, bbox_inches='tight', dpi=300)


def get_fig_pth(ckpt_path, postfix=None):
    figs_postfix = 'figs'
    postfix = os.path.join(figs_postfix, postfix) if postfix is not None else figs_postfix
    parent_path = Path(ckpt_path).parent.parent.absolute()
    fig_path = Path(os.path.join(parent_path, postfix))
    os.makedirs(fig_path, exist_ok=True)
    return fig_path

def plot_heatmap(heatmap, ckpt_path=None, title='default', postfix=None):
    if ckpt_path is not None:
        path = get_fig_pth(ckpt_path, postfix=postfix)
    else:
        path = postfix

    # show
    fig = plt.figure()
    ax = plt.imshow(heatmap, cmap='hot', interpolation='nearest')
    plt.tick_params(left=False, bottom=False)
    # cbar = ax.collections[0].colorbar
    cbar = plt.colorbar(ax)
    cbar.ax.tick_params(labelsize=15)
    plt.axis('off')
    plt.show()
    fig.savefig(os.path.join(path, title + " heat_map.png"), bbox_inches='tight', dpi=300)
    pd.DataFrame(heatmap.numpy()).to_csv(os.path.join(path, title + " heat_map.csv"))

def plot_heatmap_at_path(heatmap, save_path, ckpt_path=None, title='default', postfix=None):
    if ckpt_path is not None:
        path = get_fig_pth(ckpt_path, postfix=postfix)
    else:
        path = postfix

    # show
    fig = plt.figure()
    ax = plt.imshow(heatmap, cmap='hot', interpolation='nearest')
    plt.tick_params(left=False, bottom=False)
    # cbar = ax.collections[0].colorbar
    cbar = plt.colorbar(ax)
    cbar.ax.tick_params(labelsize=15)
    plt.axis('off')
    plt.show()
    fig.savefig(os.path.join(save_path, title + "_heat_map.png"), bbox_inches='tight', dpi=300)
    pd.DataFrame(heatmap.numpy()).to_csv(os.path.join(save_path, title + "_heat_map.csv"))

def plot_confusionmatrix(preds, classes, classnames, ckpt_path, postfix=None, title="", get_fig_path=True):
    fig, ax = plt.subplots(figsize=(30, 30))
    preds_max = np.argmax(preds.cpu().numpy(), axis=-1)
    disp = ConfusionMatrixDisplay.from_predictions(classes.cpu().numpy(), preds_max, display_labels=classnames, normalize='true', xticks_rotation='vertical', ax=ax)
    disp.plot()

    if get_fig_path:
        fig_path = get_fig_pth(ckpt_path, postfix=postfix)
    else:
        fig_path = ckpt_path
    if not os.path.exists(fig_path):
        os.mkdir(fig_path)

    print(fig_path)
    fig.savefig(os.path.join(fig_path, title + " heat_map.png"))

def plot_confusionmatrix_colormap(preds, classes, classnames, ckpt_path, postfix=None, title="", get_fig_path=True):
    fig, ax = plt.subplots(figsize=(30, 30))
    preds_max = np.argmax(preds.cpu().numpy(), axis=-1)
    class_labels = list(range(len(classnames)))
    disp = ConfusionMatrixDisplay.from_predictions(classes.cpu().numpy(), preds_max, display_labels=class_labels, normalize='true', xticks_rotation='vertical', ax=ax, cmap='coolwarm')
    disp.plot()

    if get_fig_path:
        fig_path = get_fig_pth(ckpt_path, postfix=postfix)
    else:
        fig_path = ckpt_path
    if not os.path.exists(fig_path):
        os.mkdir(fig_path)

    print(fig_path)
    fig.savefig(os.path.join(fig_path, title + " heat_map_coolwarm.png"))


class Histogram_plotter:
    def __init__(self, codes_per_phylolevel, n_phylolevels, n_embed,
                 converter,
                 indx_to_label,
                 ckpt_path, directory):
        self.codes_per_phylolevel = codes_per_phylolevel
        self.n_phylolevels = n_phylolevels
        self.n_embed = n_embed
        self.converter = converter
        self.ckpt_path = ckpt_path
        self.directory = directory
        self.indx_to_label = indx_to_label

    def plot_histograms(self, histograms, species_indx, is_nonattribute=False, prefix="species"):
        fig, axs = plt.subplots(self.codes_per_phylolevel, self.n_phylolevels, figsize=(5 * self.n_phylolevels, 30))
        for i, ax in enumerate(axs.reshape(-1)):
            ax.hist(histograms[i], density=True, range=(0, self.n_embed - 1), bins=self.n_embed)

            if not is_nonattribute:
                code_location, level = self.converter.get_code_reshaped_index(i)
                ax.set_title("code " + str(code_location) + "/level " + str(level))
            else:
                ax.set_title("code " + str(i))

        plt.show()
        sub_dir = 'attribute' if not is_nonattribute else 'non_attribute'
        fig.savefig(os.path.join(get_fig_pth(self.ckpt_path, postfix=self.directory + '/' + sub_dir), "{}_{}_{}_histogram.png".format(prefix, species_indx, self.indx_to_label[species_indx])), bbox_inches='tight', dpi=300)
        plt.close(fig)
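get_fig_pth anchors every output under a figs/ directory two levels above the checkpoint, so most helpers only need a ckpt_path. A small sketch with made-up paths and toy tensors:

import torch
from ldm.plotting_utils import plot_heatmap, save_image_grid

ckpt = "logs/run/checkpoints/last.ckpt"  # figures land under logs/run/figs/

# Toy 16x16 distance matrix -> "demo heat_map.png" and "demo heat_map.csv"
plot_heatmap(torch.rand(16, 16), ckpt_path=ckpt, title="demo")

# Grid of 8 images in [-1, 1] -> figs/grids/code_changes_demo.png
images = torch.rand(8, 3, 32, 32) * 2 - 1
save_image_grid(images, ckpt_path=ckpt, subfolder="grids", postfix="demo")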
ldm/util.py
ADDED
@@ -0,0 +1,243 @@
import importlib
import os
import torch
import hashlib
import requests
import numpy as np
from tqdm import tqdm
from collections import abc
from einops import rearrange
from functools import partial

import multiprocessing as mp
from threading import Thread
from queue import Queue

from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont

URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}

MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}

def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()

def log_txt_as_img(wh, xc, size=10):
    # wh a tuple of (width, height)
    # xc a list of captions to plot
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
        nc = int(40 * (wh[0] / 256))
        lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Can't encode string for logging. Skipping.")

        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts

def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(chunk_size)

def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path


def ismap(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] > 3)


def isimage(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params


def instantiate_from_config(config):
    if not "target" in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
    # create dummy dataset instance

    # run prefetching
    if idx_to_fn:
        res = func(data, worker_id=idx)
    else:
        res = func(data)
    Q.put([idx, res])
    Q.put("Done")


def parallel_data_prefetch(
        func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
):
    # if target_data_type not in ["ndarray", "list"]:
    #     raise ValueError(
    #         "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
    #     )
    if isinstance(data, np.ndarray) and target_data_type == "list":
        raise ValueError("list expected but function got ndarray.")
    elif isinstance(data, abc.Iterable):
        if isinstance(data, dict):
            print(
                'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
            )
            data = list(data.values())
        if target_data_type == "ndarray":
            data = np.asarray(data)
        else:
            data = list(data)
    else:
        raise TypeError(
            f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
        )

    if cpu_intensive:
        Q = mp.Queue(1000)
        proc = mp.Process
    else:
        Q = Queue(1000)
        proc = Thread
    # spawn processes
    if target_data_type == "ndarray":
        arguments = [
            [func, Q, part, i, use_worker_id]
            for i, part in enumerate(np.array_split(data, n_proc))
        ]
    else:
        step = (
            int(len(data) / n_proc + 1)
            if len(data) % n_proc != 0
            else int(len(data) / n_proc)
        )
        arguments = [
            [func, Q, part, i, use_worker_id]
            for i, part in enumerate(
                [data[i: i + step] for i in range(0, len(data), step)]
            )
        ]
    processes = []
    for i in range(n_proc):
        p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
        processes += [p]

    # start processes
    print("Start prefetching...")
    import time

    start = time.time()
    gather_res = [[] for _ in range(n_proc)]
    try:
        for p in processes:
            p.start()

        k = 0
        while k < n_proc:
            # get result
            res = Q.get()
            if res == "Done":
                k += 1
            else:
                gather_res[res[0]] = res[1]

    except Exception as e:
        print("Exception: ", e)
        for p in processes:
            p.terminate()

        raise e
    finally:
        for p in processes:
            p.join()
        print(f"Prefetching complete. [{time.time() - start} sec.]")

    if target_data_type == 'ndarray':
        if not isinstance(gather_res[0], np.ndarray):
            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)

        # order outputs
        return np.concatenate(gather_res, axis=0)
    elif target_data_type == 'list':
        out = []
        for r in gather_res:
            out.extend(r)
        return out
    else:
        return gather_res
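instantiate_from_config drives most object construction in this codebase: a config mapping (usually loaded via OmegaConf) names a class under "target" and its constructor kwargs under "params". A self-contained sketch; torch.nn.Linear is just a stand-in target, not something this commit configures:

from ldm.util import instantiate_from_config, count_params

# Any importable class can be a target; the dotted path is resolved with
# importlib and called with the params dict as keyword arguments.
config = {
    "target": "torch.nn.Linear",
    "params": {"in_features": 8, "out_features": 2},
}
layer = instantiate_from_config(config)
count_params(layer, verbose=True)  # prints "Linear has 0.00 M params." (18 params)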