|
import os |
|
import cv2 |
|
import time |
|
import tqdm |
|
import numpy as np |
|
|
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
import trimesh |
|
import rembg |
|
|
|
from cam_utils import orbit_camera, OrbitCamera |
|
from mesh_renderer import Renderer |
|
|
|
|
|
|
|
class GUI:
    """Viewer / trainer that refines a mesh ``Renderer`` against an input image and guidance models.

    The object can be driven interactively (``render``) or headless (``train``).
    Configuration is read from ``opt``; the attributes accessed in this class are
    ``gui``, ``W``, ``H``, ``radius``, ``fovy``, ``input``, ``prompt``,
    ``elevation``, ``lambda_sd``, ``lambda_zero123``, ``ref_size``,
    ``batch_size``, ``outdir``, ``save_path`` and ``mesh_format``.

    NOTE(review): every GUI code path calls ``dpg.*`` but no import binding the
    name ``dpg`` is visible in this part of the file. Presumably
    ``import dearpygui.dearpygui as dpg`` belongs with the module imports --
    confirm, otherwise GUI mode fails with ``NameError``.
    """

    def __init__(self, opt):
        """Set up camera, renderer and default state, then optionally build the GUI.

        Args:
            opt: configuration object (attribute access), kept by reference so
                that GUI edits such as the save path are visible to later calls.
        """
        self.opt = opt
        self.gui = opt.gui
        self.W = opt.W
        self.H = opt.H
        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)

        self.mode = "image"
        self.seed = "random"
        # seed actually applied by the last ``seed_everything`` call
        self.last_seed = None

        # Image buffers are row-major (height first); the element count is the
        # same as the former (W, H, 3) allocation.
        self.buffer_image = np.ones((self.H, self.W, 3), dtype=np.float32)
        self.need_update = True  # the displayed frame must be re-rendered

        # models
        self.device = torch.device("cuda")
        self.bg_remover = None

        self.guidance_sd = None
        self.guidance_zero123 = None

        self.enable_sd = False
        self.enable_zero123 = False

        # renderer
        self.renderer = Renderer(opt).to(self.device)

        # input image
        self.input_img = None
        self.input_mask = None
        self.input_img_torch = None
        self.input_mask_torch = None
        self.input_img_torch_channel_last = None
        self.overlay_input_img = False
        self.overlay_input_img_ratio = 0.5

        # input text
        self.prompt = ""
        self.negative_prompt = ""

        # training state
        self.training = False
        self.optimizer = None
        self.fixed_cam = None  # (pose, projection) of the reference view, set by prepare_train
        self.step = 0
        self.train_steps = 1  # optimisation steps per train_step() call

        # load input data from cmdline
        if self.opt.input is not None:
            self.load_input(self.opt.input)

        # an explicit prompt overrides one loaded from a caption file
        if self.opt.prompt is not None:
            self.prompt = self.opt.prompt

        if self.gui:
            dpg.create_context()
            self.register_dpg()
            self.test_step()

    def __del__(self):
        """Tear down the GUI context when GUI mode was enabled."""
        # ``getattr`` keeps the finaliser silent when __init__ failed before
        # ``self.gui`` was assigned.
        if getattr(self, "gui", False):
            dpg.destroy_context()

    def seed_everything(self):
        """Seed Python hashing, NumPy and PyTorch from ``self.seed``.

        ``self.seed`` is parsed as an integer; when it is not parseable (for
        example the default ``"random"``) a random seed in ``[0, 1000000)`` is
        drawn instead. The seed that was applied is stored in ``self.last_seed``.
        """
        try:
            seed = int(self.seed)
        except (TypeError, ValueError):
            seed = np.random.randint(0, 1000000)

        os.environ["PYTHONHASHSEED"] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        # Benchmark mode may pick a different cuDNN algorithm on every run,
        # which works against the reproducibility this method is after.
        torch.backends.cudnn.benchmark = False

        self.last_seed = seed

    def _has_objective(self):
        """Return True when at least one loss term can be formed in ``train_step``."""
        return self.input_img_torch is not None or self.enable_sd or self.enable_zero123

    @staticmethod
    def _sample_ssaa():
        """Draw a random ``ssaa`` render factor, clamped to ``[0.125, 2.0]``."""
        return min(2.0, max(0.125, 2 * np.random.random()))

    def prepare_train(self):
        """Reset the optimiser and load whatever guidance the current inputs require.

        Stable Diffusion guidance is enabled when ``opt.lambda_sd > 0`` and a
        prompt is set; zero123 guidance when ``opt.lambda_zero123 > 0`` and an
        input image is loaded. Guidance models are imported and constructed
        lazily the first time they are needed.
        """
        self.step = 0

        # setup training
        self.optimizer = torch.optim.Adam(self.renderer.get_params())

        # fixed reference camera used for the input-image loss
        pose = orbit_camera(self.opt.elevation, 0, self.opt.radius)
        self.fixed_cam = (pose, self.cam.perspective)

        self.enable_sd = self.opt.lambda_sd > 0 and self.prompt != ""
        self.enable_zero123 = self.opt.lambda_zero123 > 0 and self.input_img is not None

        # lazy load guidance models
        if self.guidance_sd is None and self.enable_sd:
            print("[INFO] loading SD...")
            from guidance.sd_utils import StableDiffusion
            self.guidance_sd = StableDiffusion(self.device)
            print("[INFO] loaded SD!")

        if self.guidance_zero123 is None and self.enable_zero123:
            print("[INFO] loading zero123...")
            from guidance.zero123_utils import Zero123
            self.guidance_zero123 = Zero123(self.device)
            print("[INFO] loaded zero123!")

        # input image: HWC numpy -> 1CHW tensor resized to the reference size
        if self.input_img is not None:
            ref_hw = (self.opt.ref_size, self.opt.ref_size)

            self.input_img_torch = torch.from_numpy(self.input_img).permute(2, 0, 1).unsqueeze(0).to(self.device)
            self.input_img_torch = F.interpolate(
                self.input_img_torch, ref_hw, mode="bilinear", align_corners=False
            )

            self.input_mask_torch = torch.from_numpy(self.input_mask).permute(2, 0, 1).unsqueeze(0).to(self.device)
            self.input_mask_torch = F.interpolate(
                self.input_mask_torch, ref_hw, mode="bilinear", align_corners=False
            )

            # HWC copy, matching the layout train_step compares the render against
            self.input_img_torch_channel_last = self.input_img_torch[0].permute(1, 2, 0).contiguous()

        # prepare embeddings
        with torch.no_grad():
            if self.enable_sd:
                self.guidance_sd.get_text_embeds([self.prompt], [self.negative_prompt])

            if self.enable_zero123:
                self.guidance_zero123.get_img_embeds(self.input_img_torch)

    def train_step(self):
        """Run ``self.train_steps`` optimisation steps and flag the view for redraw.

        Each step sums an MSE loss against the input image at the fixed
        reference camera (when an image is loaded) and MSE losses between
        randomly sampled renders and their guidance-refined versions (when SD /
        zero123 guidance is enabled).

        Raises:
            RuntimeError: if there is neither an input image nor enabled
                guidance, i.e. no loss term can be formed. ``prepare_train``
                must have been called first.
        """
        if not self._has_objective():
            # Previously ``loss`` stayed the int 0 here and ``loss.backward()``
            # failed with an opaque AttributeError after rendering.
            raise RuntimeError(
                "train_step has nothing to optimize: load an input image or set a prompt, "
                "and call prepare_train() first."
            )

        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        render_resolution = 512
        elevation = self.opt.elevation
        # vertical offset range for random views; computed once since it only
        # depends on the configured elevation
        min_ver = max(min(-30, -30 - elevation), -80 - elevation)
        max_ver = min(max(30, 30 - elevation), 80 - elevation)

        for _ in range(self.train_steps):
            self.step += 1

            loss = 0

            ### known view
            if self.input_img_torch is not None:
                ssaa = self._sample_ssaa()
                out = self.renderer.render(*self.fixed_cam, self.opt.ref_size, self.opt.ref_size, ssaa=ssaa)

                # rgb loss, restricted to covered pixels that face the camera
                image = out["image"]
                valid_mask = ((out["alpha"] > 0) & (out["viewcos"] > 0.5)).detach()
                loss = loss + F.mse_loss(image * valid_mask, self.input_img_torch_channel_last * valid_mask)

            ### novel views
            images = []
            vers, hors, radii = [], [], []
            for _ in range(self.opt.batch_size):
                # random camera offsets relative to the reference view
                ver = np.random.randint(min_ver, max_ver)
                hor = np.random.randint(-180, 180)
                radius = 0

                vers.append(ver)
                hors.append(hor)
                radii.append(radius)

                pose = orbit_camera(elevation + ver, hor, self.opt.radius + radius)

                ssaa = self._sample_ssaa()
                out = self.renderer.render(pose, self.cam.perspective, render_resolution, render_resolution, ssaa=ssaa)

                # channel-last render -> 1CHW for batching
                image = out["image"].permute(2, 0, 1).contiguous().unsqueeze(0)
                images.append(image)

            images = torch.cat(images, dim=0)

            # guidance loss: pull the renders toward their refined versions
            target_hw = (render_resolution, render_resolution)

            if self.enable_sd:
                refined_images = self.guidance_sd.refine(images, strength=0.6).float()
                refined_images = F.interpolate(refined_images, target_hw, mode="bilinear", align_corners=False)
                loss = loss + self.opt.lambda_sd * F.mse_loss(images, refined_images)

            if self.enable_zero123:
                refined_images = self.guidance_zero123.refine(images, vers, hors, radii, strength=0.6).float()
                refined_images = F.interpolate(refined_images, target_hw, mode="bilinear", align_corners=False)
                loss = loss + self.opt.lambda_zero123 * F.mse_loss(images, refined_images)

            # optimize step
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        self.need_update = True

        if self.gui:
            dpg.set_value("_log_train_time", f"{t:.4f}ms")
            dpg.set_value(
                "_log_train_log",
                f"step = {self.step: 5d} (+{self.train_steps: 2d}) loss = {loss.item():.4f}",
            )

    @torch.no_grad()
    def test_step(self):
        """Re-render the current camera view into ``self.buffer_image`` if it is stale.

        Does nothing unless ``self.need_update`` is set. In GUI mode the texture
        and the inference-time label are refreshed as well.
        """
        if not self.need_update:
            return

        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        out = self.renderer.render(self.cam.pose, self.cam.perspective, self.H, self.W)

        buffer_image = out[self.mode]  # presumably [H, W, C] -- TODO confirm against Renderer

        if self.mode in ("depth", "alpha"):
            # replicate the last dimension three times so it can be shown as RGB
            buffer_image = buffer_image.repeat(1, 1, 3)
            if self.mode == "depth":
                # min-max normalise; the epsilon guards a constant depth map
                low = buffer_image.min()
                buffer_image = (buffer_image - low) / (buffer_image.max() - low + 1e-20)

        self.buffer_image = buffer_image.contiguous().clamp(0, 1).detach().cpu().numpy()

        # blend the input image over the render
        if self.overlay_input_img and self.input_img is not None:
            ratio = self.overlay_input_img_ratio
            self.buffer_image = self.buffer_image * (1 - ratio) + self.input_img * ratio

        self.need_update = False

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        if self.gui:
            # a zero elapsed time must not crash the render loop
            fps = int(1000 / t) if t > 0 else "inf"
            dpg.set_value("_log_infer_time", f"{t:.4f}ms ({fps} FPS)")
            dpg.set_value("_texture", self.buffer_image)

    def load_input(self, file):
        """Load the reference image (and an optional caption) from ``file``.

        Images without an alpha channel are run through ``rembg`` to obtain one.
        The image is resized to ``(W, H)``, scaled to ``[0, 1]``, composited on
        white and stored as RGB in ``self.input_img``; the alpha channel goes to
        ``self.input_mask``. If ``file`` ends in ``_rgba.png`` and a sibling
        ``*_caption.txt`` exists, its content becomes ``self.prompt``.

        Raises:
            ValueError: if the image cannot be read.
        """
        print(f'[INFO] load image from {file}...')
        img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising
            raise ValueError(f"Cannot read image from {file}")

        if img.ndim == 2:
            # single-channel image: expand so the channel handling below applies
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if img.shape[-1] == 3:
            if self.bg_remover is None:
                self.bg_remover = rembg.new_session()
            img = rembg.remove(img, session=self.bg_remover)

        img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)
        img = img.astype(np.float32) / 255.0

        self.input_mask = img[..., 3:]
        # white background
        self.input_img = img[..., :3] * self.input_mask + (1 - self.input_mask)
        # bgr to rgb
        self.input_img = self.input_img[..., ::-1].copy()

        # load prompt; when the name has no "_rgba.png" suffix ``replace`` returns
        # the image path itself, which must not be read as a caption
        file_prompt = file.replace("_rgba.png", "_caption.txt")
        if file_prompt != file and os.path.exists(file_prompt):
            print(f'[INFO] load prompt from {file_prompt}...')
            with open(file_prompt, "r") as f:
                self.prompt = f.read().strip()

    def save_model(self):
        """Export the current mesh to ``<opt.outdir>/<opt.save_path>.<opt.mesh_format>``."""
        os.makedirs(self.opt.outdir, exist_ok=True)

        path = os.path.join(self.opt.outdir, self.opt.save_path + '.' + self.opt.mesh_format)
        self.renderer.export_mesh(path)

        print(f"[INFO] save model to {path}.")

    def register_dpg(self):
        """Create the texture, windows, widgets, input handlers and viewport of the GUI."""
        ### register texture

        with dpg.texture_registry(show=False):
            dpg.add_raw_texture(
                self.W,
                self.H,
                self.buffer_image,
                format=dpg.mvFormat_Float_rgb,
                tag="_texture",
            )

        ### register window

        # the rendered image, as the primary window
        with dpg.window(
            tag="_primary_window",
            width=self.W,
            height=self.H,
            pos=[0, 0],
            no_move=True,
            no_title_bar=True,
            no_scrollbar=True,
        ):
            dpg.add_image("_texture")

        # control window
        with dpg.window(
            label="Control",
            tag="_control_window",
            width=600,
            height=self.H,
            pos=[self.W, 0],
            no_move=True,
            no_title_bar=True,
        ):
            # button theme
            with dpg.theme() as theme_button:
                with dpg.theme_component(dpg.mvButton):
                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)

            # timer
            with dpg.group(horizontal=True):
                dpg.add_text("Infer time: ")
                dpg.add_text("no data", tag="_log_infer_time")

            def callback_setattr(sender, app_data, user_data):
                # generic "widget value -> attribute named by user_data" setter
                setattr(self, user_data, app_data)

            # init section
            with dpg.collapsing_header(label="Initialize", default_open=True):

                # seed
                def callback_set_seed(sender, app_data):
                    self.seed = app_data
                    self.seed_everything()

                dpg.add_input_text(
                    label="seed",
                    default_value=self.seed,
                    on_enter=True,
                    callback=callback_set_seed,
                )

                # input image
                def callback_select_input(sender, app_data):
                    # the dialog is limited to one file, so this runs at most once
                    for k, v in app_data["selections"].items():
                        dpg.set_value("_log_input", k)
                        self.load_input(v)

                    self.need_update = True

                with dpg.file_dialog(
                    directory_selector=False,
                    show=False,
                    callback=callback_select_input,
                    file_count=1,
                    tag="file_dialog_tag",
                    width=700,
                    height=400,
                ):
                    dpg.add_file_extension("Images{.jpg,.jpeg,.png}")

                with dpg.group(horizontal=True):
                    dpg.add_button(
                        label="input",
                        callback=lambda: dpg.show_item("file_dialog_tag"),
                    )
                    dpg.add_text("", tag="_log_input")

                # overlay
                with dpg.group(horizontal=True):

                    def callback_toggle_overlay_input_img(sender, app_data):
                        self.overlay_input_img = not self.overlay_input_img
                        self.need_update = True

                    dpg.add_checkbox(
                        label="overlay image",
                        default_value=self.overlay_input_img,
                        callback=callback_toggle_overlay_input_img,
                    )

                    def callback_set_overlay_input_img_ratio(sender, app_data):
                        self.overlay_input_img_ratio = app_data
                        self.need_update = True

                    dpg.add_slider_float(
                        label="ratio",
                        min_value=0,
                        max_value=1,
                        format="%.1f",
                        default_value=self.overlay_input_img_ratio,
                        callback=callback_set_overlay_input_img_ratio,
                    )

                # prompts
                dpg.add_input_text(
                    label="prompt",
                    default_value=self.prompt,
                    callback=callback_setattr,
                    user_data="prompt",
                )

                dpg.add_input_text(
                    label="negative",
                    default_value=self.negative_prompt,
                    callback=callback_setattr,
                    user_data="negative_prompt",
                )

                # save current model
                with dpg.group(horizontal=True):
                    dpg.add_text("Save: ")

                    dpg.add_button(
                        label="model",
                        tag="_button_save_model",
                        callback=self.save_model,
                    )
                    dpg.bind_item_theme("_button_save_model", theme_button)

                    def callback_set_save_path(sender, app_data):
                        # save_model() reads opt.save_path, so the edit has to
                        # land there (writing self.save_path had no effect)
                        self.opt.save_path = app_data

                    dpg.add_input_text(
                        label="",
                        default_value=self.opt.save_path,
                        callback=callback_set_save_path,
                    )

            # training section
            with dpg.collapsing_header(label="Train", default_open=True):

                with dpg.group(horizontal=True):
                    dpg.add_text("Train: ")

                    def callback_train(sender, app_data):
                        if self.training:
                            self.training = False
                            dpg.configure_item("_button_train", label="start")
                            return

                        self.prepare_train()
                        if not self._has_objective():
                            # starting anyway would abort the render loop on
                            # the first train_step()
                            dpg.set_value(
                                "_log_train_log",
                                "nothing to optimize: load an input image or set a prompt",
                            )
                            return

                        self.training = True
                        dpg.configure_item("_button_train", label="stop")

                    dpg.add_button(
                        label="start", tag="_button_train", callback=callback_train
                    )
                    dpg.bind_item_theme("_button_train", theme_button)

                with dpg.group(horizontal=True):
                    dpg.add_text("", tag="_log_train_time")
                    dpg.add_text("", tag="_log_train_log")

            # rendering options
            with dpg.collapsing_header(label="Rendering", default_open=True):

                # mode combo
                def callback_change_mode(sender, app_data):
                    self.mode = app_data
                    self.need_update = True

                dpg.add_combo(
                    ("image", "depth", "alpha", "normal"),
                    label="mode",
                    default_value=self.mode,
                    callback=callback_change_mode,
                )

                # fov slider; the camera stores radians, the widget shows degrees
                def callback_set_fovy(sender, app_data):
                    self.cam.fovy = np.deg2rad(app_data)
                    self.need_update = True

                dpg.add_slider_int(
                    label="FoV (vertical)",
                    min_value=1,
                    max_value=120,
                    format="%d deg",
                    # an int slider needs an int default, not a numpy float
                    default_value=int(round(np.rad2deg(self.cam.fovy))),
                    callback=callback_set_fovy,
                )

        ### register camera handlers (only active while the image window has focus)

        def callback_camera_drag_rotate_or_draw_mask(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.orbit(dx, dy)
            self.need_update = True

        def callback_camera_wheel_scale(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return

            delta = app_data

            self.cam.scale(delta)
            self.need_update = True

        def callback_camera_drag_pan(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.pan(dx, dy)
            self.need_update = True

        with dpg.handler_registry():
            dpg.add_mouse_drag_handler(
                button=dpg.mvMouseButton_Left,
                callback=callback_camera_drag_rotate_or_draw_mask,
            )
            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
            dpg.add_mouse_drag_handler(
                button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan
            )

        dpg.create_viewport(
            title="Gaussian3D",
            width=self.W + 600,
            height=self.H + (45 if os.name == "nt" else 0),
            resizable=False,
        )

        ### global theme: no padding around the rendered image
        with dpg.theme() as theme_no_padding:
            with dpg.theme_component(dpg.mvAll):
                dpg.add_theme_style(
                    dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core
                )
                dpg.add_theme_style(
                    dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core
                )
                dpg.add_theme_style(
                    dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core
                )

        dpg.bind_item_theme("_primary_window", theme_no_padding)

        dpg.setup_dearpygui()

        ### use a larger font when the font file sits in the working directory
        if os.path.exists("LXGWWenKai-Regular.ttf"):
            with dpg.font_registry():
                with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font:
                    dpg.bind_font(default_font)

        dpg.show_viewport()

    def render(self):
        """Run the GUI loop: train while training is active, redraw, present a frame."""
        assert self.gui, "render() requires GUI mode (opt.gui)"
        while dpg.is_dearpygui_running():
            # update texture every frame
            if self.training:
                self.train_step()
            self.test_step()
            dpg.render_dearpygui_frame()

    def train(self, iters=500):
        """Headless mode: run ``iters`` calls of ``train_step`` and then save the mesh.

        Args:
            iters: number of ``train_step`` calls; when it is not positive,
                training is skipped and the model is saved as-is.
        """
        if iters > 0:
            self.prepare_train()
            for _ in tqdm.trange(iters):
                self.train_step()

        # save
        self.save_model()
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the config, locate the mesh to refine, then run
    # either the interactive GUI or a headless training pass.
    import argparse
    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to the yaml config file")
    args, extras = parser.parse_known_args()

    # values given on the command line override the ones from the yaml file
    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))

    # Fall back to the default mesh location when no mesh was given.
    # ``get`` also covers a config without a ``mesh`` key at all, where plain
    # attribute access can raise instead of returning None.
    if opt.get("mesh") is None:
        default_path = os.path.join(opt.outdir, opt.save_path + '_mesh.' + opt.mesh_format)
        if os.path.exists(default_path):
            opt.mesh = default_path
        else:
            raise ValueError(f"Cannot find mesh from {default_path}, must specify --mesh explicitly!")

    gui = GUI(opt)

    if opt.gui:
        gui.render()
    else:
        gui.train(opt.iters_refine)
|
|