# Garment3dKabeer / loop.py
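"""Optimization loop for garment mesh deformation.

Deforms a source garment mesh toward a target by optimizing per-face
Jacobians (Neural Jacobian Fields) with Adam, guided by FashionCLIP
text/image embeddings of differentiable renders (nvdiffrast) and
regularized with PyTorch3D mesh losses (chamfer, edge length, normal
consistency, Laplacian smoothing) plus a triangle-area term.
"""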
import os
import sys
import pathlib
import logging

import numpy as np
import torch
import kornia
import yaml
import nvdiffrast.torch as dr

from easydict import EasyDict
from PIL import Image
# Apply torchvision compatibility fixes
try:
import torchvision
print(f"torchvision {torchvision.__version__} imported successfully")
except (RuntimeError, AttributeError) as e:
if "operator torchvision::nms does not exist" in str(e) or "extension" in str(e):
print("Applying torchvision compatibility fixes...")
# Apply the same fixes as in app.py
import types
if not hasattr(torch, 'ops'):
torch.ops = types.SimpleNamespace()
if not hasattr(torch.ops, 'torchvision'):
torch.ops.torchvision = types.SimpleNamespace()
# Create dummy functions for problematic operators
torchvision_ops = ['nms', 'roi_align', 'roi_pool', 'ps_roi_align', 'ps_roi_pool']
for op_name in torchvision_ops:
if not hasattr(torch.ops.torchvision, op_name):
if op_name == 'nms':
setattr(torch.ops.torchvision, op_name, lambda *args, **kwargs: torch.zeros(0, dtype=torch.int64))
else:
setattr(torch.ops.torchvision, op_name, lambda *args, **kwargs: torch.zeros(0))
# Try importing again
try:
import torchvision
print("torchvision imported successfully after fixes")
except Exception as e2:
print(f"torchvision still has issues, but continuing: {e2}")
else:
print(f"Other torchvision error: {e}")
except ImportError:
print("torchvision not available, continuing without it")
from NeuralJacobianFields import SourceMesh
from nvdiffmodeling.src import mesh, obj, render, texture
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utilities.video import Video
from utilities.helpers import cosine_avg, create_scene, l1_avg
from utilities.camera import CameraBatch, get_camera_params
from utilities.clip_spatial import CLIPVisualEncoder
from utilities.resize_right import resize, cubic, linear, lanczos2, lanczos3
from packages.fashion_clip.fashion_clip.fashion_clip import FashionCLIP
from utils import *
from get_embeddings import *
from pytorch3d.structures import Meshes
from pytorch3d.loss import (
chamfer_distance,
mesh_edge_loss,
mesh_laplacian_smoothing,
mesh_normal_consistency,
)
from pytorch3d.ops import sample_points_from_meshes
def total_triangle_area(vertices):
    """Sum the areas of all triangles in a triangle soup.

    Expects `vertices` of shape (3 * num_triangles, 3): three consecutive
    rows per triangle.
    """
    num_triangles = vertices.shape[0] // 3
    triangle_vertices = vertices.view(num_triangles, 3, 3)
    # Cross product of two edge vectors; its norm is twice the triangle area.
    cross_products = torch.cross(triangle_vertices[:, 1] - triangle_vertices[:, 0],
                                 triangle_vertices[:, 2] - triangle_vertices[:, 0],
                                 dim=1)
    areas = 0.5 * torch.norm(cross_products, dim=1)
    return torch.sum(areas)


def triangle_size_regularization(vertices):
    # Discourage surface growth by penalizing the squared total triangle area.
    return total_triangle_area(vertices) ** 2
def loop(cfg):
clip_flag = True
output_path = pathlib.Path(cfg['output_path'])
os.makedirs(output_path, exist_ok=True)
with open(output_path / 'config.yml', 'w') as f:
yaml.dump(cfg, f, default_flow_style=False)
cfg = EasyDict(cfg)
print(f'Output directory {cfg.output_path} created')
os.makedirs(output_path / 'tmp', exist_ok=True)
device = torch.device(f'cuda:{cfg.gpu}')
torch.cuda.set_device(device)
# Read mode flags from config if available, otherwise use defaults
text_input = cfg.get('text_input', False)
image_input = cfg.get('image_input', False)
fashion_image = cfg.get('fashion_image', False)
fashion_text = cfg.get('fashion_text', True) # Default to fashion text mode
use_target_mesh = cfg.get('use_target_mesh', True)
    # Always use FashionCLIP; regular CLIP is disabled to avoid the loading
    # issues handled by the compatibility shims above.
    CLIP_embeddings = False
print('Loading FashionCLIP model...')
try:
fclip = FashionCLIP('fashion-clip')
print('FashionCLIP loaded successfully')
except Exception as e:
print(f'Error loading FashionCLIP: {e}')
raise RuntimeError(f"Failed to load FashionCLIP: {e}")
# Load CLIPVisualEncoder with error handling
print('Loading CLIPVisualEncoder...')
try:
fe = CLIPVisualEncoder(cfg.consistency_clip_model, cfg.consistency_vit_stride, device)
print('CLIPVisualEncoder loaded successfully')
except Exception as e:
print(f'Error loading CLIPVisualEncoder: {e}')
print('Continuing without CLIPVisualEncoder...')
fe = None
# Use FashionCLIP for all modes to avoid CLIP loading issues
if fashion_image:
print('Processing with fashion image embeddings')
target_direction_embeds, delta_direction_embeds = get_fashion_img_embeddings(fclip, cfg, device, True)
elif fashion_text:
print('Processing with fashion text embeddings')
target_direction_embeds, delta_direction_embeds = get_fashion_text_embeddings(fclip, cfg, device)
elif text_input or image_input:
print('WARNING: Regular CLIP embeddings are disabled, using FashionCLIP instead')
if text_input:
target_direction_embeds, delta_direction_embeds = get_fashion_text_embeddings(fclip, cfg, device)
else:
target_direction_embeds, delta_direction_embeds = get_fashion_img_embeddings(fclip, cfg, device, True)
    # Standard CLIP image-normalization constants.
    clip_mean = torch.tensor([0.48145466, 0.45782750, 0.40821073], device=device)
    clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device)
# output video
video = Video(cfg.output_path)
# GL Context - with fallback for headless environments
print('Initializing nvdiffrast GL context...')
try:
glctx = dr.RasterizeGLContext()
print('nvdiffrast GL context initialized successfully')
use_gl_rendering = True
except Exception as e:
print(f'Error initializing nvdiffrast GL context: {e}')
        print('This is likely due to missing EGL support in a headless environment.')
print('Using fallback rendering approach...')
glctx = None
use_gl_rendering = False
def fallback_render_mesh(mesh, mvp, campos, lightpos, light_power, resolution, **kwargs):
"""
Fallback rendering function when GL context is not available
Returns a simple colored mesh visualization
"""
try:
# Check if return_rast_map is requested
return_rast_map = kwargs.get('return_rast_map', False)
            # Flat-shaded placeholder without real lighting.
            device = mesh.v_pos.device if hasattr(mesh, 'v_pos') and mesh.v_pos is not None else torch.device('cuda')
            # Match the camera batch size so downstream indexing (e.g. log_idx) works.
            batch_size = mvp.shape[0] if hasattr(mvp, 'shape') and mvp.dim() == 3 else 1
if return_rast_map:
# Return a dummy rasterization map for consistency
rast_map = torch.zeros(batch_size, resolution, resolution, 4, device=device)
rast_map[..., 3] = 1.0 # Set alpha to 1
return rast_map
else:
# Create a simple colored output
color = torch.ones(batch_size, resolution, resolution, 3, device=device) * 0.5 # Gray color
# Add some basic shading based on vertex positions
if hasattr(mesh, 'v_pos') and mesh.v_pos is not None:
# Normalize vertex positions for coloring
v_pos_norm = (mesh.v_pos - mesh.v_pos.min(dim=0)[0]) / (mesh.v_pos.max(dim=0)[0] - mesh.v_pos.min(dim=0)[0] + 1e-8)
# Use vertex positions to create a simple color gradient
color = color * 0.3 + v_pos_norm.mean(dim=0).unsqueeze(0).unsqueeze(0).unsqueeze(0) * 0.7
return color
except Exception as e:
print(f"Fallback rendering failed: {e}")
# Return a simple colored square as last resort
device = mesh.v_pos.device if hasattr(mesh, 'v_pos') and mesh.v_pos is not None else torch.device('cuda')
if kwargs.get('return_rast_map', False):
return torch.zeros(1, resolution, resolution, 4, device=device)
else:
return torch.ones(1, resolution, resolution, 3, device=device) * 0.5
def safe_render_mesh(glctx, mesh, mvp, campos, lightpos, light_power, resolution, **kwargs):
"""
Safe rendering function that uses GL context if available, otherwise falls back
"""
if glctx is not None and use_gl_rendering:
try:
return render.render_mesh(glctx, mesh, mvp, campos, lightpos, light_power, resolution, **kwargs)
except Exception as e:
print(f"GL rendering failed, using fallback: {e}")
return fallback_render_mesh(mesh, mvp, campos, lightpos, light_power, resolution, **kwargs)
else:
return fallback_render_mesh(mesh, mvp, campos, lightpos, light_power, resolution, **kwargs)
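    # All render calls below go through safe_render_mesh, so a missing GL/EGL
    # stack degrades to the flat-shaded fallback instead of crashing the loop.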
load_mesh = get_mesh(cfg.mesh, output_path, cfg.retriangulate, cfg.bsdf)
if use_target_mesh:
target_mesh = get_mesh(cfg.target_mesh, output_path, cfg.retriangulate, cfg.bsdf, 'mesh_target.obj')
# We construct a Meshes structure for the target mesh
trg_mesh_p3d = Meshes(verts=[target_mesh.v_pos], faces=[target_mesh.t_pos_idx])
jacobian_source = SourceMesh.SourceMesh(0, str(output_path / 'tmp' / 'mesh.obj'), {}, 1, ttype=torch.float)
if len(list((output_path / 'tmp').glob('*.npz'))) > 0:
        logging.warning(f'Using existing Jacobian .npz files in {str(output_path)}/tmp/! Please check if this is intentional.')
# Check if the mesh file exists before loading
mesh_file_path = output_path / 'tmp' / 'mesh.obj'
print(f"Looking for mesh file at: {mesh_file_path}")
print(f"Absolute path: {mesh_file_path.absolute()}")
if not mesh_file_path.exists():
# List files in the tmp directory to see what's there
tmp_dir = output_path / 'tmp'
if tmp_dir.exists():
print(f"Files in {tmp_dir}:")
for file in tmp_dir.iterdir():
print(f" - {file.name}")
else:
print(f"Tmp directory {tmp_dir} does not exist")
raise FileNotFoundError(f"Mesh file not found: {mesh_file_path}. This indicates an issue with the mesh loading process.")
print(f"Mesh file exists at: {mesh_file_path}")
print("Loading jacobian source...")
jacobian_source.load()
jacobian_source.to(device)
# Validate that jacobian source loaded properly
if not hasattr(jacobian_source, 'jacobians_from_vertices') or jacobian_source.jacobians_from_vertices is None:
raise ValueError("Failed to load jacobian source. The jacobians_from_vertices method is not available.")
print("Jacobian source loaded successfully")
with torch.no_grad():
gt_jacobians = jacobian_source.jacobians_from_vertices(load_mesh.v_pos.unsqueeze(0))
# Validate that gt_jacobians is not empty
if gt_jacobians is None or gt_jacobians.shape[0] == 0:
raise ValueError("Failed to generate jacobians from vertices. This indicates an issue with the mesh or jacobian source.")
print(f"Generated jacobians with shape: {gt_jacobians.shape}")
gt_jacobians.requires_grad_(True)
optimizer = torch.optim.Adam([gt_jacobians], lr=cfg.lr)
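    # The optimization variable is the stack of per-face 3x3 Jacobians, not the
    # vertex positions: each iteration, vertices_from_jacobians() recovers the
    # vertices from gt_jacobians via a Poisson solve, which tends to give
    # smoother, lower-distortion deformations than optimizing vertices directly.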
cams_data = CameraBatch(
cfg.train_res,
[cfg.dist_min, cfg.dist_max],
[cfg.azim_min, cfg.azim_max],
[cfg.elev_alpha, cfg.elev_beta, cfg.elev_max],
[cfg.fov_min, cfg.fov_max],
cfg.aug_loc,
cfg.aug_light,
cfg.aug_bkg,
cfg.batch_size,
rand_solid=True
)
cams = torch.utils.data.DataLoader(cams_data, cfg.batch_size, num_workers=0, pin_memory=True)
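    # Each batch draws random camera poses (distance, azimuth, elevation, FOV)
    # with optional location/light/background augmentation, so the CLIP losses
    # see the garment from many viewpoints.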
best_losses = {'CLIP': np.inf, 'total': np.inf}
for out_type in ['final', 'best_clip', 'best_total', 'target_final']:
os.makedirs(output_path / f'mesh_{out_type}', exist_ok=True)
os.makedirs(output_path / 'images', exist_ok=True)
logger = SummaryWriter(str(output_path / 'logs'))
rot_ang = 0.0
t_loop = tqdm(range(cfg.epochs), leave=False)
    if cfg.resize_method == 'cubic':
        resize_method = cubic
    elif cfg.resize_method == 'linear':
        resize_method = linear
    elif cfg.resize_method == 'lanczos2':
        resize_method = lanczos2
    elif cfg.resize_method == 'lanczos3':
        resize_method = lanczos3
    else:
        raise ValueError(f'Unknown resize_method: {cfg.resize_method}')
    for it in t_loop:
        # Recover updated vertex positions from the current Jacobians.
        n_vert = jacobian_source.vertices_from_jacobians(gt_jacobians).squeeze()
        if n_vert is None or n_vert.shape[0] == 0:
            raise ValueError("Generated vertices are empty. This indicates an issue with the jacobian source or mesh loading.")
        if it % cfg.log_interval == 0:
            print(f"Iteration {it}: generated {n_vert.shape[0]} vertices")
# TODO: More texture code required to make it work ...
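        # Blur the kd/ks/normal maps and swap in a flat-gray albedo; a plausible
        # reading is that this keeps the CLIP guidance focused on geometry
        # rather than texture detail.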
ready_texture = texture.Texture2D(
kornia.filters.gaussian_blur2d(
load_mesh.material['kd'].data.permute(0, 3, 1, 2),
kernel_size=(7, 7),
sigma=(3, 3),
).permute(0, 2, 3, 1).contiguous()
)
kd_notex = texture.Texture2D(torch.full_like(ready_texture.data, 0.5))
ready_specular = texture.Texture2D(
kornia.filters.gaussian_blur2d(
load_mesh.material['ks'].data.permute(0, 3, 1, 2),
kernel_size=(7, 7),
sigma=(3, 3),
).permute(0, 2, 3, 1).contiguous()
)
ready_normal = texture.Texture2D(
kornia.filters.gaussian_blur2d(
load_mesh.material['normal'].data.permute(0, 3, 1, 2),
kernel_size=(7, 7),
sigma=(3, 3),
).permute(0, 2, 3, 1).contiguous()
)
        # Deformed mesh for this iteration
m = mesh.Mesh(
n_vert,
load_mesh.t_pos_idx,
material={
'bsdf': cfg.bsdf,
'kd': kd_notex,
'ks': ready_specular,
'normal': ready_normal,
},
base=load_mesh # gets uvs etc from here
)
deformed_mesh_p3d = Meshes(verts=[m.v_pos], faces=[m.t_pos_idx])
render_mesh = create_scene([m.eval()], sz=512)
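        # On the first iteration, snapshot the undeformed scene; base_render
        # further below is computed once from this base_mesh.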
if it == 0:
base_mesh = render_mesh.clone()
base_mesh = mesh.auto_normals(base_mesh)
base_mesh = mesh.compute_tangents(base_mesh)
render_mesh = mesh.auto_normals(render_mesh)
render_mesh = mesh.compute_tangents(render_mesh)
if use_target_mesh:
# Target mesh
m_target = mesh.Mesh(
target_mesh.v_pos,
target_mesh.t_pos_idx,
material={
'bsdf': cfg.bsdf,
'kd': kd_notex,
'ks': ready_specular,
'normal': ready_normal,
},
base=target_mesh
)
render_target_mesh = create_scene([m_target.eval()], sz=512)
if it == 0:
base_target_mesh = render_target_mesh.clone()
base_target_mesh = mesh.auto_normals(base_target_mesh)
base_target_mesh = mesh.compute_tangents(base_target_mesh)
render_target_mesh = mesh.auto_normals(render_target_mesh)
render_target_mesh = mesh.compute_tangents(render_target_mesh)
# Logging mesh
if it % cfg.log_interval == 0:
with torch.no_grad():
params = get_camera_params(
cfg.log_elev,
rot_ang,
cfg.log_dist,
cfg.log_res,
cfg.log_fov,
)
rot_ang += 5
log_mesh = mesh.unit_size(render_mesh.eval(params))
log_image = safe_render_mesh(glctx, log_mesh, params['mvp'], params['campos'], params['lightpos'], cfg.log_light_power, cfg.log_res)
log_image = video.ready_image(log_image)
logger.add_mesh('predicted_mesh', vertices=log_mesh.v_pos.unsqueeze(0), faces=log_mesh.t_pos_idx.unsqueeze(0), global_step=it)
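        # Rescale the sampled camera distances to the mesh's current half-extent
        # so the garment stays framed as it deforms.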
if cfg.adapt_dist and it > 0:
with torch.no_grad():
v_pos = m.v_pos.clone()
vmin = v_pos.amin(dim=0)
vmax = v_pos.amax(dim=0)
v_pos -= (vmin + vmax) / 2
mult = torch.cat([v_pos.amin(dim=0), v_pos.amax(dim=0)]).abs().amax().cpu()
cams.dataset.dist_min = cfg.dist_min * mult
cams.dataset.dist_max = cfg.dist_max * mult
params_camera = next(iter(cams))
for key in params_camera:
params_camera[key] = params_camera[key].to(device)
final_mesh = render_mesh.eval(params_camera)
train_render = safe_render_mesh(glctx, final_mesh, params_camera['mvp'], params_camera['campos'], params_camera['lightpos'], cfg.light_power, cfg.train_res)
        # Renders come back channels-last (B, H, W, 3); convert to NCHW for resize/CLIP.
        if train_render.shape[-1] == 3:
            train_render = train_render.permute(0, 3, 1, 2)
train_render = resize(train_render, out_shape=(224, 224), interp_method=resize_method)
if use_target_mesh:
final_target_mesh = render_target_mesh.eval(params_camera)
train_target_render = safe_render_mesh(glctx, final_target_mesh, params_camera['mvp'], params_camera['campos'], params_camera['lightpos'], cfg.light_power, cfg.train_res)
            # Convert channels-last render to NCHW.
            if train_target_render.shape[-1] == 3:
                train_target_render = train_target_render.permute(0, 3, 1, 2)
train_target_render = resize(train_target_render, out_shape=(224, 224), interp_method=resize_method)
train_rast_map = safe_render_mesh(
glctx,
final_mesh,
params_camera['mvp'],
params_camera['campos'],
params_camera['lightpos'],
cfg.light_power,
cfg.train_res,
return_rast_map=True
)
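        # The rasterization map is consumed by the multi-view consistency loss
        # (compute_mv_cl) to relate CLIP patch features across views.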
if it == 0:
params_camera = next(iter(cams))
for key in params_camera:
params_camera[key] = params_camera[key].to(device)
base_render = safe_render_mesh(glctx, base_mesh.eval(params_camera), params_camera['mvp'], params_camera['campos'], params_camera['lightpos'], cfg.light_power, cfg.train_res)
            # Convert channels-last render to NCHW.
            if base_render.shape[-1] == 3:
                base_render = base_render.permute(0, 3, 1, 2)
base_render = resize(base_render, out_shape=(224, 224), interp_method=resize_method)
if it % cfg.log_interval_im == 0:
log_idx = torch.randperm(cfg.batch_size)[:5]
s_log = train_render[log_idx, :, :, :]
s_log = torchvision.utils.make_grid(s_log)
ndarr = s_log.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr)
im.save(str(output_path / 'images' / f'epoch_{it}.png'))
if use_target_mesh:
s_log_target = train_target_render[log_idx, :, :, :]
s_log_target = torchvision.utils.make_grid(s_log_target)
ndarr = s_log_target.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr)
im.save(str(output_path / 'images' / f'epoch_{it}_target.png'))
            # Checkpoint the current mesh alongside the logged images.
            obj.write_obj(
                str(output_path / 'mesh_final'),
                m.eval()
            )
optimizer.zero_grad()
        # CLIP-normalized render, consumed by the multi-view consistency loss.
        normalized_clip_render = (train_render - clip_mean[None, :, None, None]) / clip_std[None, :, None, None]
        # FashionCLIP feature-space loss between deformed and target renders
        # (this and the losses below assume use_target_mesh=True).
        deformed_features = fclip.encode_image_tensors(train_render)
        target_features = fclip.encode_image_tensors(train_target_render)
        garment_loss = l1_avg(deformed_features, target_features)
        # Pixel-space L1 between the two renders.
        l1_loss = l1_avg(train_render, train_target_render)
        # Sample 10k points from each surface and compare the point clouds via:
        sample_src = sample_points_from_meshes(deformed_mesh_p3d, 10000)
        sample_trg = sample_points_from_meshes(trg_mesh_p3d, 10000)
        # (a) the chamfer distance between the sampled point clouds,
        loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)
        loss_chamfer *= 25.  # empirical weight
        # (b) the edge length of the predicted mesh,
        loss_edge = mesh_edge_loss(deformed_mesh_p3d)
        # (c) mesh normal consistency,
        loss_normal = mesh_normal_consistency(deformed_mesh_p3d)
        # (d) uniform Laplacian smoothing,
        loss_laplacian = mesh_laplacian_smoothing(deformed_mesh_p3d, method="uniform")
        # and (e) triangle-area regularization. Gather per-face vertex triples so
        # the area term sees actual mesh triangles rather than arbitrary triples.
        tri_verts = deformed_mesh_p3d.verts_packed()[deformed_mesh_p3d.faces_packed()].view(-1, 3)
        loss_triangles = triangle_size_regularization(tri_verts) / 100000.
logger.add_scalar('l1_loss', l1_loss, global_step=it)
logger.add_scalar('garment_loss', garment_loss, global_step=it)
        # Jacobian regularization: pull per-face Jacobians toward the identity,
        # i.e. toward the undeformed source geometry.
        r_loss = ((gt_jacobians - torch.eye(3, 3, device=device)) ** 2).mean()
logger.add_scalar('jacobian_regularization', r_loss, global_step=it)
if cfg.consistency_loss_weight != 0 and fe is not None and train_rast_map is not None:
consistency_loss = compute_mv_cl(final_mesh, fe, normalized_clip_render, params_camera, train_rast_map, cfg, device)
        else:
            # No visual encoder (or zero weight): reuse r_loss as a placeholder
            # so the weighted sum below stays well-defined.
            consistency_loss = r_loss
logger.add_scalar('consistency_loss', consistency_loss, global_step=it)
logger.add_scalar('chamfer', loss_chamfer, global_step=it)
logger.add_scalar('edge', loss_edge, global_step=it)
logger.add_scalar('normal', loss_normal, global_step=it)
logger.add_scalar('laplacian', loss_laplacian, global_step=it)
logger.add_scalar('triangles', loss_triangles, global_step=it)
        # After 1000 iterations, drop the CLIP feature and multi-view consistency
        # terms and keep a small Jacobian regularizer; the pixel L1 and the
        # geometric regularizers continue to refine the shape.
        if it > 1000 and clip_flag:
            cfg.clip_weight = 0
            cfg.consistency_loss_weight = 0
            cfg.regularize_jacobians_weight = 0.025
            clip_flag = False
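        # Total objective: weighted FashionCLIP feature loss + pixel L1 +
        # Jacobian identity regularizer + multi-view consistency + the
        # PyTorch3D geometric regularizers summed below.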
regularizers = loss_chamfer + loss_edge + loss_normal + loss_laplacian + loss_triangles
total_loss = (cfg.clip_weight * garment_loss + cfg.delta_clip_weight * l1_loss +
cfg.regularize_jacobians_weight * r_loss + cfg.consistency_loss_weight * consistency_loss + regularizers)
logger.add_scalar('total_loss', total_loss, global_step=it)
total_loss.backward()
optimizer.step()
        t_loop.set_description(
            f'L1 = {cfg.delta_clip_weight * l1_loss.item():.4f}, '
            f'CLIP = {cfg.clip_weight * garment_loss.item():.4f}, '
            f'Jacb = {cfg.regularize_jacobians_weight * r_loss.item():.4f}, '
            f'MVC = {cfg.consistency_loss_weight * consistency_loss.item():.4f}, '
            f'Chamf = {loss_chamfer.item():.4f}, '
            f'Edge = {loss_edge.item():.4f}, '
            f'Normal = {loss_normal.item():.4f}, '
            f'Lapl = {loss_laplacian.item():.4f}, '
            f'Triang = {loss_triangles.item():.4f}, '
            f'Total = {total_loss.item():.4f}')
video.close()
obj.write_obj(
str(output_path / 'mesh_final'),
m.eval()
)
return
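

# A minimal driver sketch (hypothetical config path; the actual entry point
# lives elsewhere in the repo, e.g. app.py):
#
#   import yaml
#   with open('configs/garment.yml') as f:
#       cfg = yaml.safe_load(f)
#   loop(cfg)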