import enum
import torch
import math
import fcbh.utils


def lcm(a, b):
    """Return the least common multiple of the integers ``a`` and ``b``.

    Returns 0 when either argument is 0 (same convention as ``math.lcm``),
    instead of dividing by ``gcd(0, 0) == 0``.
    """
    # TODO: eventually replace by math.lcm (added in python3.9)
    divisor = math.gcd(a, b)
    if divisor == 0:
        # Only reachable when a == 0 and b == 0.
        return 0
    return abs(a * b) // divisor


class CONDRegular:
    """Wrapper around a conditioning tensor stored in ``self.cond``.

    Two instances can be concatenated only when their tensors have the
    same shape.
    """

    def __init__(self, cond):
        self.cond = cond

    def _copy_with(self, cond):
        """Build a new instance of the same (sub)class holding ``cond``."""
        return self.__class__(cond)

    def process_cond(self, batch_size, device, **kwargs):
        """Return a copy whose tensor is batch-adjusted and moved to ``device``.

        The batch adjustment is delegated to
        ``fcbh.utils.repeat_to_batch_size`` (presumably repeats/trims along
        dim 0 -- confirm against that helper). Extra keyword arguments are
        accepted and ignored.
        """
        batched = fcbh.utils.repeat_to_batch_size(self.cond, batch_size)
        return self._copy_with(batched.to(device))

    def can_concat(self, other):
        """Return True when ``other.cond`` has the same shape as ``self.cond``."""
        return self.cond.shape == other.cond.shape

    def concat(self, others):
        """Return ``torch.cat`` of this tensor followed by those of ``others``."""
        return torch.cat([self.cond] + [x.cond for x in others])


class CONDNoiseShape(CONDRegular):
    """Conditioning that is cropped to an ``area`` before batching."""

    def process_cond(self, batch_size, device, area, **kwargs):
        """Crop ``self.cond`` to ``area``, batch-adjust it, move it to ``device``.

        ``area`` is indexed as: ``area[2]`` = start on dim 2, ``area[0]`` =
        extent on dim 2, ``area[3]`` = start on dim 3, ``area[1]`` = extent
        on dim 3. (Presumably (height, width, y, x) -- confirm with callers.)
        """
        dim2_start, dim3_start = area[2], area[3]
        data = self.cond[
            :,
            :,
            dim2_start:area[0] + dim2_start,
            dim3_start:area[1] + dim3_start,
        ]
        batched = fcbh.utils.repeat_to_batch_size(data, batch_size)
        return self._copy_with(batched.to(device))


class CONDCrossAttn(CONDRegular):
    """Conditioning whose dim 1 may differ in length between instances.

    Tensors of different dim-1 lengths are brought to a common length (the
    least common multiple) by repeating along dim 1 before concatenation.
    """

    def can_concat(self, other):
        """Return True when ``other`` can be concatenated with this instance.

        Equal shapes always qualify. Otherwise dims 0 and 2 must match and
        the repeat factor needed for the shorter tensor must not exceed 4.
        An empty dim 1 on either side yields False rather than a
        ``ZeroDivisionError``, because repeating an empty tensor can never
        reach the target length.
        """
        s1 = self.cond.shape
        s2 = other.cond.shape
        if s1 == s2:
            return True

        if s1[0] != s2[0] or s1[2] != s2[2]:  # these 2 cases should not happen
            return False

        shortest = min(s1[1], s2[1])
        if shortest == 0:
            return False

        mult_min = lcm(s1[1], s2[1])
        diff = mult_min // shortest
        # arbitrary limit on the padding because it's probably going to
        # impact performance negatively if it's too much
        return diff <= 4

    def concat(self, others):
        """Concatenate with ``others`` after equalising dim-1 lengths.

        The target length is the least common multiple of all dim-1 lengths;
        shorter tensors are repeated along dim 1 to reach it.
        """
        conds = [self.cond] + [x.cond for x in others]

        crossattn_max_len = self.cond.shape[1]
        for c in conds[1:]:
            crossattn_max_len = lcm(crossattn_max_len, c.shape[1])

        out = []
        for c in conds:
            if c.shape[1] < crossattn_max_len:
                # padding with repeat doesn't change result
                c = c.repeat(1, crossattn_max_len // c.shape[1], 1)
            out.append(c)
        return torch.cat(out)


class CONDConstant(CONDRegular):
    """Conditioning holding a value that is passed through unchanged.

    The constructor is inherited from ``CONDRegular`` (it only stores
    ``cond``).
    """

    def process_cond(self, batch_size, device, **kwargs):
        """Return a copy holding the same value; all arguments are ignored."""
        return self._copy_with(self.cond)

    def can_concat(self, other):
        """Return True unless ``self.cond != other.cond`` is truthy."""
        if self.cond != other.cond:
            return False
        return True

    def concat(self, others):
        """Return ``self.cond``; ``others`` is not consulted."""
        return self.cond