|
|
import yaml |
|
|
import string |
|
|
import secrets |
|
|
import os |
|
|
|
|
|
import torch |
|
|
import wandb |
|
|
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint |
|
|
from torchdyn.core import NeuralODE |
|
|
|
|
|
import torch |
|
|
|
|
|
@torch.no_grad()
def gather_local_starts(x0s, X0_pool, N, k=64):
    """Sample ``N`` starting points from the ``k`` nearest neighbours of each query.

    For every row of ``x0s`` the ``min(k, len(X0_pool))`` closest rows of
    ``X0_pool`` (squared Euclidean distance) are located and ``N`` of them are
    drawn at random.

    - When ``N`` does not exceed the neighbourhood size, the draw is without
      replacement (a random permutation truncated to ``N``).
    - Otherwise every neighbour is used once and the remaining slots are
      filled by sampling neighbours with replacement.

    Either way each query gets exactly ``N`` starts.

    Args:
        x0s: Query points, shape ``(B, G)``.
        X0_pool: Candidate pool, shape ``(P, G)``.
        N: Number of starts to return per query.
        k: Neighbourhood size; clipped to ``P``.

    Returns:
        Tensor of shape ``(B, N, G)`` whose rows are taken from ``X0_pool``.

    Raises:
        ValueError: If ``N > 0`` but there is no neighbour to sample from
            (empty pool or ``k <= 0``).
    """
    B = x0s.size(0)
    pool_size = X0_pool.size(0)
    kk = min(k, pool_size)

    if N > 0 and kk <= 0:
        raise ValueError(
            f"cannot draw {N} starts: neighbourhood size is {kk} "
            f"(pool size {pool_size}, k={k})"
        )
    if B == 0:
        # torch.stack([]) would raise; return an explicitly empty batch instead.
        return X0_pool.new_empty((0, N, X0_pool.size(-1)))

    # Squared distances only serve to rank neighbours.
    d2 = torch.cdist(x0s, X0_pool).pow(2)
    knn_idx = d2.topk(k=kk, largest=False).indices  # (B, kk)
    device = knn_idx.device

    clusters = []
    for b in range(B):
        choices = knn_idx[b]
        shuffled = choices[torch.randperm(kk, device=device)]
        if N <= kk:
            pick = shuffled[:N]
        else:
            # Not enough distinct neighbours: use each once, then top up with
            # replacement so the cluster still has N members.
            extra = choices[torch.randint(kk, (N - kk,), device=device)]
            pick = torch.cat([shuffled, extra], dim=0)
        clusters.append(X0_pool[pick])

    return torch.stack(clusters, dim=0)
|
|
|
|
|
def _coupling_from_sampler(ot_sampler, x0s, x1s):
    """Return the sampler's coupling matrix for ``(x0s, x1s)``, or ``None``.

    Tries ``ot_sampler.coupling`` first and then ``ot_sampler.plan``. A
    non-tensor result (e.g. a NumPy array) is converted with ``torch.as_tensor``
    onto ``x1s``'s device.
    """
    if hasattr(ot_sampler, "coupling"):
        P = ot_sampler.coupling(x0s, x1s)
    elif hasattr(ot_sampler, "plan"):
        P = ot_sampler.plan(x0s, x1s)
    else:
        return None
    if P is None or torch.is_tensor(P):
        return P
    return torch.as_tensor(P, device=x1s.device)


def _sample_indices_from_row(row, N, replace):
    """Draw ``N`` column indices using one coupling row as sampling weights.

    Negative entries are clamped to zero. A row with no positive mass falls
    back to uniform weights, because ``torch.multinomial`` rejects an all-zero
    distribution.

    When ``replace`` is False and the row has fewer than ``N`` positive
    entries, the last drawn index is repeated to reach ``N``.
    """
    if N == 0:
        return torch.empty(0, dtype=torch.long, device=row.device)

    probs = row.clamp_min(0)
    total = probs.sum()
    if total <= 0:
        probs = torch.ones_like(probs)
    else:
        probs = probs / total.clamp_min(1e-12)

    if replace:
        return torch.multinomial(probs, num_samples=N, replacement=True)

    n_support = min(N, int((probs > 0).sum().item()))
    j = torch.multinomial(probs, num_samples=n_support, replacement=False)
    if n_support < N:
        # Not enough distinct support points: pad by repeating the last draw.
        j = torch.cat([j, j[-1:].expand(N - n_support)], dim=0)
    return j


def _sample_matches_via_sampler(ot_sampler, x0_b, x1s, N, replace):
    """Draw ``N`` targets for the single source ``x0_b`` via ``sample_plan``.

    First tries one batched call with ``n_pairs=N``. If the sampler rejects
    that keyword (``TypeError``), it falls back to ``N`` single-pair calls. The
    index of each returned target in ``x1s`` is recovered by nearest-neighbour
    lookup.

    Returns:
        ``(x1_match, j)`` of shapes ``(N, G)`` and ``(N,)``.
    """
    G = x1s.shape[1]
    try:
        _, x1_match = ot_sampler.sample_plan(x0_b, x1s, replace=replace, n_pairs=N)
    except TypeError:
        # NOTE(review): if sample_plan solves OT between the single point x0_b
        # and all of x1s with uniform marginals, the resulting plan is forced
        # to be uniform, so these draws would be random rather than OT-aligned
        # — confirm against the sampler in use.
        singles = [
            ot_sampler.sample_plan(x0_b, x1s, replace=replace)[1].reshape(1, G)
            for _ in range(N)
        ]
        x1_match = torch.cat(singles, dim=0)
    else:
        x1_match = x1_match.reshape(N, G)
    # Recover indices for both paths (the batched path previously left j unset).
    j = torch.cdist(x1_match, x1s).argmin(dim=1)
    return x1_match, j


@torch.no_grad()
def make_aligned_clusters(ot_sampler, x0s, x1s, N, replace=True, k_local=128):
    """Build per-source clusters of start points paired with OT-sampled targets.

    For each source row ``x0s[b]`` this produces:

    - a cluster of starts drawn from its local neighbourhood within ``x0s``
      (via ``gather_local_starts(x0s, x0s, N, k=k_local)``), and
    - ``N`` targets from ``x1s`` chosen by the OT sampler.

    Targets come from whichever of these the sampler provides, in order:

    1. ``coupling(x0s, x1s)`` or ``plan(x0s, x1s)``: row ``b`` of the returned
       matrix is used as (unnormalised) weights over the rows of ``x1s``
       (assumed shape ``(B, M)`` — TODO confirm with the sampler).
    2. Otherwise ``sample_plan``: called per source point, batched if it
       accepts ``n_pairs``.

    Args:
        ot_sampler: Object exposing ``coupling``, ``plan`` or ``sample_plan``.
        x0s: Source points, shape ``(B, G)``.
        x1s: Target points, shape ``(M, G)``.
        N: Number of targets sampled per source point.
        replace: Whether targets are sampled with replacement.
        k_local: Neighbourhood size passed to ``gather_local_starts``.

    Returns:
        ``(x0_clusters, x1_clusters, idx1)``:

        - ``x0_clusters``: the output of ``gather_local_starts``, cast to
          ``x0s``'s device and dtype.
        - ``x1_clusters``: shape ``(B, N, G)``.
        - ``idx1``: shape ``(B, N)``, the ``x1s`` row index of each target.
    """
    device, dtype = x0s.device, x0s.dtype
    B, G = x0s.shape

    x0_clusters = gather_local_starts(x0s, x0s, N, k=k_local).to(device=device, dtype=dtype)
    x1_clusters = torch.empty((B, N, G), device=device, dtype=dtype)
    idx1 = torch.empty((B, N), device=device, dtype=torch.long)
    if B == 0 or N == 0:
        return x0_clusters, x1_clusters, idx1

    P = _coupling_from_sampler(ot_sampler, x0s, x1s)

    for b in range(B):
        if P is not None:
            j = _sample_indices_from_row(P[b], N, replace)
            x1_match = x1s[j]
        else:
            x1_match, j = _sample_matches_via_sampler(
                ot_sampler, x0s[b:b + 1], x1s, N, replace
            )
        x1_clusters[b] = x1_match
        idx1[b] = j

    return x0_clusters, x1_clusters, idx1
|
|
|
|
|
|
|
|
def load_config(path):
    """Load a YAML configuration file and return its parsed contents.

    The file is decoded as UTF-8 explicitly, rather than with the platform's
    locale-dependent default. It is parsed with ``yaml.safe_load``, which builds
    only plain YAML data types.

    An empty (or ``null``) document yields ``{}`` instead of ``None``. That
    result can be passed straight to ``merge_config``, which iterates
    ``.items()``.

    Args:
        path: Path to the YAML file.

    Returns:
        The parsed YAML document, or ``{}`` when the document is empty.
    """
    with open(path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    return {} if config is None else config
|
|
|
|
|
|
|
|
def merge_config(args, config_updates):
    """Overwrite attributes of ``args`` with values from ``config_updates``.

    Keys are applied in the mapping's iteration order. Every key must already
    exist as an attribute of ``args``. The first unknown key raises
    ``ValueError``, and attributes set before it keep their new values.

    Args:
        args: Object whose attributes are updated in place.
        config_updates: Mapping of attribute name to new value.

    Returns:
        The same ``args`` object.

    Raises:
        ValueError: If a key is not an existing attribute of ``args``.
    """
    for name, new_value in config_updates.items():
        if hasattr(args, name):
            setattr(args, name, new_value)
            continue
        raise ValueError(
            f"Unknown configuration parameter '{name}' found in the config file."
        )
    return args
|
|
|
|
|
|
|
|
def generate_group_string(length=16):
    """Return a random string of ``length`` ASCII letters and digits.

    Each character is chosen independently with ``secrets.choice``. A
    non-positive ``length`` produces an empty string.
    """
    charset = string.ascii_letters + string.digits
    return "".join(map(secrets.choice, [charset] * length))