Spaces:
Build error
Build error
File size: 11,505 Bytes
dee645c 4b0990a dee645c 4b0990a dee645c 4b0990a dee645c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
from transformers import set_seed
from tqdm.auto import trange
from PIL import Image
import numpy as np
import random
import utils
import torch
# Declarative option spec: a list of (section title, options) pairs, where each
# option is a (name, default, spec) triple. In the entries below, spec is a type
# (str/int/float/bool), a 2-tuple of numbers, or a list of allowed values.
# NOTE(review): presumably a 2-tuple is a (min, max) range and a list is a set of
# choices for a UI form; the code that consumes CONFIG_SPEC is not visible here.
# The option names are the keys that PulsarCLIP reads from its `args`.
CONFIG_SPEC = [
    ("General", [
        # Text prompt whose CLIP embedding the renders are optimized towards.
        ("text", "A cloud at dawn", str),
        # Number of optimization iterations (one randomly-viewed render each).
        ("iterations", 5000, (0, 7500)),
        # Passed to transformers.set_seed before the parameters are initialized.
        ("seed", 12, int),
    ]),
    ("Rendering", [
        # Output resolution handed to the pulsar renderer as (w, h).
        ("w", 224, [224, 252]),
        ("h", 224, [224, 252]),
        # ("show_every", 50, int),
        # Extra loop iterations after optimization; the preview code in
        # PulsarCLIP.generate that would use them is commented out.
        ("showoff", 5000, (0, 10000)),
        # Number of camera revolutions used only by the commented-out preview.
        ("turns", 4, int),
        # Copied into the camera parameter vector built by PulsarCLIP.camera.
        ("focal_length", 0.1, float),
        ("plane_width", 0.1, float),
        # NOTE(review): not referenced by the code visible in this file.
        ("shade_strength", 0.25, float),
        # Passed straight to the pulsar renderer call.
        ("gamma", 0.5, float),
        ("max_depth", 7, float),
        # Camera distance from the origin, and the std-dev of its random jitter.
        ("offset", 5, float),
        ("offset_random", 0.75, float),
        # Scale of the Gaussian jitter added to the camera position.
        ("xyz_random", 0.25, float),
        # Random views draw their altitude from [-altitude_range/2, altitude_range/2].
        ("altitude_range", 0.3, float),
        # Iterations (random views) accumulated per optimizer step.
        ("augments", 4, int),
    ]),
    ("Optimization", [
        # Divides the cosine-restart period of the LR scheduler.
        ("epochs", 1, int),
        # Base learning rate; each parameter group uses a fixed multiple of it.
        ("lr", 0.5, float),
        #@markdown CLIP loss type, might improve the results
        ("loss_type", "spherical", ["spherical", "cosine"]),
        #@markdown CLIP loss weight
        ("clip_weight", 1.0, float), #@param {type: "number"}
    ]),
    ("Elements", [
        # Number of independently parameterized objects.
        ("num_objects", 256, int),
        #@markdown Number of dimensions. 0 is for point clouds (default), 1 will make
        #@markdown strokes, 2 will make planes, 3 produces little cubes
        ("ndim", 0, [0, 1, 2, 3]), #@param {type: "integer"}
        #@markdown Opacity scale:
        # Opacities are a sigmoid rescaled into [min_opacity, max_opacity],
        # geometrically instead of linearly when log_opacity is set.
        ("min_opacity", 1e-4, float), #@param {type: "number"}
        ("max_opacity", 1.0, float), #@param {type: "number"}
        ("log_opacity", False, bool), #@param {type: "boolean"}
        # Radii are rescaled the same way into [min_radius, max_radius].
        ("min_radius", 0.030, float),
        ("max_radius", 0.070, float),
        ("log_radius", False, bool),
        # TODO dynamically decide bezier_res
        #@markdown Bezier resolution: how many points a line/plane/cube will have. Not applicable to points
        # Each object is sampled on a bezier_res ** ndim grid.
        ("bezier_res", 8, int), #@param {type: "integer"}
        #@markdown Maximum scale of parameters: position, velocity, acceleration
        ("pos_scale", 0.4, float), #@param {type: "number"}
        ("vel_scale", 0.15, float), #@param {type: "number"}
        ("acc_scale", 0.15, float), #@param {type: "number"}
        #@markdown Scale of each individual 3D object. Master control for velocity and acceleration scale.
        ("scale", 1, float), #@param {type: "number"}
    ]),
]
# TODO: one day separate the config into multiple parts and split this megaobject into multiple objects
# 2022/08/09: halfway done
class PulsarCLIP(object):
    """Optimize a set of 3D primitives so that renders of them match a CLIP text prompt.

    Each of ``num_objects`` objects is a quadratic curve/patch in ``ndim``
    parameters t (position + sum over axes of velocity * t + acceleration * t**2;
    ``ndim == 0`` is a single point), sampled on a ``bezier_res ** ndim`` grid.
    Every sample is drawn as a sphere by pytorch3d's pulsar renderer. Random
    views are scored against the prompt embedding with CLIP, and the object
    parameters are optimized with Adam.

    ``args`` is a mapping of option name -> value (the options listed in the
    module's ``CONFIG_SPEC``, plus an optional ``"device"``).
    """

    def __init__(self, args):
        args = DotDict(**args)
        set_seed(args.seed)
        self.args = args
        self.device = args.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        # Defer the import so that we can import `pulsar_clip` and then install `pytorch3d`
        import pytorch3d.renderer.points.pulsar as ps
        self.ndim = int(self.args.ndim)
        # Every object is expanded into bezier_res ** ndim rendered spheres.
        n_spheres = self.args.num_objects * (self.args.bezier_res ** self.ndim)
        self.renderer = ps.Renderer(self.args.w, self.args.h, n_spheres).to(self.device)
        # Learnable per-object parameters, all initialized from randn:
        #   bezier_pos: xyz position + 1 radius logit
        #   bezier_vel / bezier_acc: 3 coefficients per parameter axis
        #   bezier_col: RGB + opacity logits, then 4 more per parameter axis
        self.bezier_pos = torch.nn.Parameter(torch.randn((args.num_objects, 4)).to(self.device))
        self.bezier_vel = torch.nn.Parameter(torch.randn((args.num_objects, 3 * self.ndim)).to(self.device))
        self.bezier_acc = torch.nn.Parameter(torch.randn((args.num_objects, 3 * self.ndim)).to(self.device))
        self.bezier_col = torch.nn.Parameter(torch.randn((args.num_objects, 4 * (1 + self.ndim))).to(self.device))
        self.optimizer = torch.optim.Adam([
            dict(params=[self.bezier_col], lr=5e-1 * args.lr),
            dict(params=[self.bezier_pos], lr=1e-1 * args.lr),
            dict(params=[self.bezier_vel, self.bezier_acc], lr=5e-2 * args.lr),
        ])
        self.model_clip, self.preprocess_clip = utils.load_clip()
        self.model_clip.visual.requires_grad_(False)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer,
            self._restart_period(self.args.iterations, self.args.augments, self.args.epochs))
        import clip
        # The prompt embedding is a fixed target, so no autograd graph is built.
        with torch.no_grad():
            txt_emb = self.model_clip.encode_text(clip.tokenize([self.args.text]).to(self.device))[0]
        self.txt_emb = torch.nn.functional.normalize(txt_emb, dim=-1)

    @staticmethod
    def _restart_period(iterations, augments, epochs):
        """Return the first restart period (in optimizer steps) for the LR scheduler.

        The optimizer steps once every ``augments`` iterations, so a run spans
        ``iterations / augments`` steps, split over ``epochs`` restarts. The
        result is clamped to at least 1 because ``CosineAnnealingWarmRestarts``
        rejects ``T_0 < 1``, which would otherwise happen whenever
        ``iterations < augments * epochs`` (e.g. ``iterations == 0``).
        """
        return max(1, int(iterations / augments / epochs))

    def get_points(self):
        """Sample every object into renderable spheres.

        Returns:
            ``(vert_pos, vert_col, radius, opacity)`` with one row per sphere,
            ``n = num_objects * bezier_res ** ndim`` rows in object-major order:
            positions (n, 3), RGB colors in (0, 1) of shape (n, 3), radii in
            (min_radius, max_radius) and opacities in (min_opacity, max_opacity),
            each of shape (n,).
        """
        res = self.args.bezier_res
        if self.ndim > 0:
            # Parameter grid t in [0, 1]^ndim with one leading entry per axis:
            # shape (ndim, num_objects, res, ..., res, 1).
            bezier_ts = torch.stack(torch.meshgrid(
                (torch.linspace(0, 1, res, device=self.device),) * self.ndim), dim=0
            ).unsqueeze(1).repeat((1, self.args.num_objects) + (1,) * self.ndim).unsqueeze(-1)

        def interpolate_3D(pos, vel=0.0, acc=0.0, pos_scale=None, vel_scale=None, acc_scale=None, scale=None):
            # Evaluate tanh(pos) * pos_scale + scale * sum_d (tanh(vel_d) * vel_scale * t_d
            # + tanh(acc_d) * acc_scale * t_d**2) on the grid, flattened to (n, channels).
            pos_scale = self.args.pos_scale if pos_scale is None else pos_scale
            vel_scale = self.args.vel_scale if vel_scale is None else vel_scale
            acc_scale = self.args.acc_scale if acc_scale is None else acc_scale
            scale = self.args.scale if scale is None else scale
            if self.ndim == 0:
                # NOTE(review): unlike the ndim > 0 path, no tanh is applied here,
                # so point values are unbounded multiples of pos_scale.
                return pos * pos_scale
            result = 0.0
            s = pos.shape[-1]
            assert s * self.ndim == vel.shape[-1] == acc.shape[-1]
            # Per-object coefficients broadcast over the sampling grid.
            grid_shape = (-1,) + (1,) * self.ndim + (s,)
            # O(dim) sequential lol
            for d, bezier_t in enumerate(bezier_ts):  # TODO replace with fused dimension operation
                channels = slice(d * s, (d + 1) * s)
                result = (result
                          + torch.tanh(vel[..., channels]).view(grid_shape) * vel_scale * bezier_t
                          + torch.tanh(acc[..., channels]).view(grid_shape) * acc_scale * bezier_t.pow(2))
            return (result * scale
                    + torch.tanh(pos[..., :s]).view(grid_shape) * pos_scale).view(-1, s)

        def to_bezier(x):
            # Repeat a per-object value for each of its res ** ndim samples,
            # in the same row order interpolate_3D produces.
            x = x.view((-1,) + (1,) * self.ndim + (x.shape[-1],))
            return x.repeat((1,) + (res,) * self.ndim + (1,)).reshape(-1, x.shape[-1])

        def rescale(x, a, b, is_log=False):
            # Map x in [0, 1] onto [a, b], geometrically when is_log, else linearly.
            if is_log:
                return torch.exp(x * np.log(b / a) + np.log(a))
            return x * (b - a) + a

        vert_pos = interpolate_3D(self.bezier_pos[..., :3], self.bezier_vel, self.bezier_acc)
        # Colors reuse the position/velocity/acceleration scale options.
        # NOTE(review): bezier_col has only 4 * (1 + ndim) columns, so for ndim > 0 the
        # "velocity" slice [4:4 + 4 * ndim] and the "acceleration" slice [-4 * ndim:]
        # are the same columns; confirm this sharing is intended.
        vert_col = interpolate_3D(self.bezier_col[..., :4],
                                  self.bezier_col[..., 4:4 + 4 * self.ndim],
                                  self.bezier_col[..., -4 * self.ndim:])
        radius = rescale(torch.sigmoid(to_bezier(self.bezier_pos[..., -1:])[..., 0]),
                         self.args.min_radius, self.args.max_radius, is_log=self.args.log_radius)
        opacity = rescale(torch.sigmoid(vert_col[..., -1]),
                          self.args.min_opacity, self.args.max_opacity, is_log=self.args.log_opacity)
        return vert_pos, torch.sigmoid(vert_col[..., :3]), radius, opacity

    def camera(self, angle, altitude=0.0, offset=None, use_random=True, offset_random=None,
               xyz_random=None, focal_length=None, plane_width=None):
        """Build the 8-element camera parameter tensor passed to the pulsar renderer.

        The position starts at (0, 0, -offset), is passed through
        ``utils.rotate_axis`` with ``altitude`` on axis 0 and then ``angle`` on
        axis 1, and, when ``use_random`` is true, the offset and position are
        jittered with Gaussian noise scaled by ``offset_random`` and
        ``xyz_random``. Keyword arguments left as None fall back to ``self.args``.

        Returns:
            ``[x, y, z, altitude, angle, 0, focal_length, plane_width]`` as a
            float tensor on ``self.device``. NOTE(review): presumably pulsar's
            (position, rotation, focal length, sensor width) layout -- confirm.
        """
        if offset is None:
            offset = self.args.offset
        if xyz_random is None:
            xyz_random = self.args.xyz_random
        if focal_length is None:
            focal_length = self.args.focal_length
        if plane_width is None:
            plane_width = self.args.plane_width
        if offset_random is None:
            offset_random = self.args.offset_random
        # The random draws happen even when use_random is False (they are then
        # multiplied by 0), so the RNG streams advance the same way either way.
        jitter = int(use_random)
        offset = offset + np.random.normal() * offset_random * jitter
        position = torch.tensor([0, 0, -offset], dtype=torch.float)
        position = utils.rotate_axis(position, altitude, 0)
        position = utils.rotate_axis(position, angle, 1)
        position = position + torch.randn(3) * xyz_random * jitter
        return torch.tensor([position[0], position[1], position[2],
                             altitude, angle, 0,
                             focal_length, plane_width], dtype=torch.float, device=self.device)

    def render(self, cam_params=None):
        """Render the current spheres twice with ``cam_params``.

        When ``cam_params`` is None, ``self.camera(0, 0)`` is used (including its
        default random jitter).

        Returns:
            ``(rgb, opacity)``: the normal render and a second render of the same
            spheres with every color set to 0, which random_view_render uses as
            its blending mask. NOTE(review): what that mask means depends on the
            renderer's default background color -- confirm.
        """
        if cam_params is None:
            cam_params = self.camera(0, 0)
        vert_pos, vert_col, radius, opacity = self.get_points()
        rgb = self.renderer(vert_pos, vert_col, radius, cam_params,
                            self.args.gamma, self.args.max_depth, opacity=opacity)
        opacity = self.renderer(vert_pos, vert_col * 0, radius, cam_params,
                                self.args.gamma, self.args.max_depth, opacity=opacity)
        return rgb, opacity

    def random_view_render(self):
        """Render from a random view and composite it over a random noise background.

        The angle is uniform in [0, 2*pi) and the altitude uniform in
        [-altitude_range / 2, altitude_range / 2]. Each color channel of the
        background is ``utils.rand_perlin_2d_octaves`` noise, clipped and
        shifted into [0, 1], with a randomly chosen lattice size.
        """
        angle = random.uniform(0, np.pi * 2)
        half_range = self.args.altitude_range / 2
        altitude = random.uniform(-half_range, half_range)
        result, alpha = self.render(self.camera(angle, altitude))
        back = torch.zeros_like(result)
        shape = back.shape
        for channel in range(shape[-1]):
            n = random.choice([7, 14, 28])
            back[..., channel] = utils.rand_perlin_2d_octaves(shape[:-1], (n, n)).clip(-0.5, 0.5) + 0.5
        return result * (1 - alpha) + back * alpha

    def _clip_loss(self, img_emb):
        """Return the mean CLIP distance between ``img_emb`` rows and the prompt.

        ``img_emb`` should be L2-normalized along its last axis, like
        ``self.txt_emb`` (built as a single normalized embedding in __init__).

        Raises:
            NotImplementedError: if ``args.loss_type`` is not a supported type.
        """
        if self.args.loss_type == "spherical":
            # 2 * arcsin(|a - b| / 2) is the angle between unit vectors, so this
            # is half the squared angle.
            return (img_emb - self.txt_emb).norm(dim=-1).div(2).arcsin().pow(2).mul(2).mean()
        if self.args.loss_type == "cosine":
            # txt_emb is 1-D, so a plain matmul yields one dot product per row;
            # `.T` on a 1-D tensor is a deprecated no-op in PyTorch.
            return (1 - img_emb @ self.txt_emb).mean()
        raise NotImplementedError(f"CLIP loss type not supported: {self.args.loss_type}")

    def generate(self):
        """Run the optimization loop.

        For the first ``iterations`` steps a random view is rendered, scored
        against the prompt with CLIP, and its gradient is accumulated; the
        optimizer and LR scheduler step once every ``augments`` iterations.
        The following ``showoff`` iterations currently do nothing because the
        preview code below is commented out. Ctrl-C ends the loop early.
        """
        self.optimizer.zero_grad()
        try:
            for i in trange(self.args.iterations + self.args.showoff):
                if i < self.args.iterations:
                    result = self.random_view_render()
                    # NOTE(review): clamping *after* preprocess_clip discards values if that
                    # transform normalizes by mean/std -- confirm against utils.load_clip.
                    img_emb = self.model_clip.encode_image(
                        self.preprocess_clip(result.permute(2, 0, 1)).unsqueeze(0).clamp(0., 1.))
                    img_emb = torch.nn.functional.normalize(img_emb, dim=-1)
                    loss = self._clip_loss(img_emb) * self.args.clip_weight  # TODO add more loss types
                    loss.backward()
                    if i % self.args.augments == self.args.augments - 1:
                        self.optimizer.step()
                        self.optimizer.zero_grad()
                        # Tolerates a scheduler object without step() (e.g. replaced by None).
                        try:
                            self.scheduler.step()
                        except AttributeError:
                            pass
                #if i % self.args.show_every == 0:
                    #cam_params = self.camera(i / self.args.iterations * np.pi * 2 * self.args.turns, use_random=False)
                    #img_show, _ = self.render(cam_params)
                    #img = Image.fromarray((img_show.cpu().detach().numpy() * 255).astype(np.uint8))
                    #yield img
        except KeyboardInterrupt:
            pass
class DotDict(dict):
    """A ``dict`` whose items can also be read as attributes (``d.key``).

    A missing key raises ``AttributeError`` on attribute access (item access
    ``d["key"]`` still raises ``KeyError``). This lets ``hasattr``,
    ``getattr(d, name, default)`` and protocol probes such as
    ``copy.deepcopy``'s lookup of ``__deepcopy__`` behave as they do for
    ordinary objects instead of crashing with ``KeyError``.
    """

    def __getattr__(self, item):
        # Only reached when normal lookup fails, so real attributes and dict
        # methods (e.g. ``get``) still take precedence over keys.
        try:
            return self[item]
        except KeyError:
            raise AttributeError(item) from None
|