Spaces:
Paused
Paused
import os | |
import sys | |
import yaml | |
import torch | |
import random | |
import numpy as np | |
import gradio as gr | |
from pathlib import Path | |
import tempfile | |
import shutil | |
from PIL import Image | |
# Make modules living next to this script importable regardless of the
# current working directory, and also expose the bundled 'packages'
# directory when it ships alongside the script.
packages_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'packages')
sys.path.append(os.path.dirname(packages_dir))
if os.path.exists(packages_dir):
    sys.path.append(packages_dir)
# Fix for torchvision operator errors in newer PyTorch versions: make sure
# torch._custom_ops has a `_register_all_aten_ops` attribute (a no-op when
# it is missing), to avoid errors with newer PyTorch/torchvision combinations.
# Best-effort: any failure (e.g. torch._custom_ops not existing in this torch
# version) is reported and otherwise ignored.
try:
    import torch._custom_ops
    if not hasattr(torch._custom_ops, "_register_all_aten_ops"):
        setattr(torch._custom_ops, "_register_all_aten_ops", lambda: None)
except Exception as _custom_ops_err:
    # `except Exception` instead of a bare `except:` so Ctrl-C / SystemExit
    # still propagate; the shim itself stays optional.
    print(f"Skipping torch._custom_ops compatibility shim: {_custom_ops_err}")
# Function to convert OBJ to GLB format
def convert_obj_to_glb(obj_file_path, glb_file_path=None):
    """
    Convert an OBJ file to GLB format using trimesh.

    Args:
        obj_file_path: Path to the OBJ file.
        glb_file_path: Path for the output GLB file (optional). When omitted,
            the OBJ path with its suffix replaced by ``.glb`` is used.

    Returns:
        Path to the created GLB file, or None if conversion failed (trimesh
        not installed, input file missing, or a load/export error). Failures
        are reported on stdout.
    """
    # Guard only the trimesh import itself: an ImportError raised later from
    # inside trimesh (e.g. an optional loader/exporter dependency) is a
    # conversion failure, not "trimesh not available".
    try:
        import trimesh
    except ImportError:
        print("trimesh not available for GLB conversion")
        return None

    try:
        print(f"Converting {obj_file_path} to GLB format...")
        if not os.path.exists(obj_file_path):
            print(f"Error: OBJ file {obj_file_path} does not exist")
            return None
        mesh = trimesh.load(obj_file_path)
        # If no GLB path specified, create one next to the OBJ file
        if glb_file_path is None:
            glb_file_path = str(Path(obj_file_path).with_suffix('.glb'))
        mesh.export(glb_file_path, file_type='glb')
        print(f"Successfully converted to GLB: {glb_file_path}")
        return glb_file_path
    except Exception as e:
        print(f"Error converting OBJ to GLB: {e}")
        return None
# Function to ensure both OBJ and GLB files are created
def ensure_mesh_formats(output_dir):
    """
    Ensure both OBJ and GLB files are available in the output directory.

    Every ``*.obj`` found recursively under *output_dir* gets a sibling
    ``.glb`` file; convert_obj_to_glb is called when that file is missing.
    Any other ``*.glb`` files already present are collected as well.

    Args:
        output_dir: Directory containing mesh files.

    Returns:
        tuple: (obj_files, glb_files) - lists of path strings. glb_files has
        no duplicates; GLBs that could not be created are omitted. Both lists
        are empty when *output_dir* does not exist.
    """
    obj_files = []
    glb_files = []
    seen_glb = set()  # O(1) membership instead of rescanning glb_files

    output_path = Path(output_dir)
    if not output_path.exists():
        print(f"Warning: Output directory {output_dir} does not exist")
        return obj_files, glb_files

    def add_glb(path_str):
        # Keep the ordered list and the lookup set in sync.
        glb_files.append(path_str)
        seen_glb.add(path_str)

    for obj_file in output_path.rglob("*.obj"):
        obj_files.append(str(obj_file))
        print(f"Found OBJ file: {obj_file}")
        glb_file = obj_file.with_suffix('.glb')
        if glb_file.exists():
            add_glb(str(glb_file))
            print(f"GLB file already exists: {glb_file}")
            continue
        print(f"Creating GLB file for {obj_file}")
        glb_path = convert_obj_to_glb(str(obj_file), str(glb_file))
        if glb_path:
            add_glb(glb_path)
            print(f"Successfully created GLB: {glb_path}")
        else:
            print(f"Failed to create GLB for {obj_file}")

    # Also pick up existing GLB files that have no corresponding OBJ
    for glb_file in output_path.rglob("*.glb"):
        if str(glb_file) not in seen_glb:
            add_glb(str(glb_file))
            print(f"Found standalone GLB file: {glb_file}")

    print(f"Total files found: {len(obj_files)} OBJ files, {len(glb_files)} GLB files")
    return obj_files, glb_files
# Check if complex dependencies are installed
def check_complex_dependencies():
    """
    Check whether the complex dependencies are importable and print a report.

    nvdiffrast and pytorch3d determine the return value. torch-sparse counts
    as satisfied when NeuralJacobianFields/PoissonSystem.py (next to this
    script) contains ``USE_TORCH_SPARSE = False``. Otherwise it is listed as
    missing, but as before it does not change the return value.
    torch-scatter and torchvision problems are only reported.

    Returns:
        bool: True if both nvdiffrast and pytorch3d could be imported.
    """
    dependencies_ok = True
    missing_deps = []

    try:
        import nvdiffrast
        print("β nvdiffrast available")
    except ImportError:
        print("β nvdiffrast not available")
        dependencies_ok = False
        missing_deps.append("nvdiffrast")

    try:
        import pytorch3d
        print("β pytorch3d available")
    except ImportError:
        print("β pytorch3d not available")
        dependencies_ok = False
        missing_deps.append("pytorch3d")

    # Check if torch-sparse is available or disabled
    try:
        import torch_sparse
        print("β torch-sparse available")
    except ImportError:
        # Anchor the path to this script's directory so the result does not
        # depend on the process's current working directory.
        poisson_system_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "NeuralJacobianFields", "PoissonSystem.py")
        try:
            with open(poisson_system_path, 'r', encoding='utf-8') as f:
                torch_sparse_disabled = "USE_TORCH_SPARSE = False" in f.read()
        except (OSError, UnicodeDecodeError):
            # Missing/unreadable file: treat torch-sparse as simply missing
            # (the former bare `except:` also swallowed KeyboardInterrupt).
            torch_sparse_disabled = False
        if torch_sparse_disabled:
            print("β torch-sparse is disabled, using built-in PyTorch sparse operations")
        else:
            print("β torch-sparse not available")
            missing_deps.append("torch-sparse")

    # Check if torch-scatter is available (not critical)
    try:
        import torch_scatter
        print("β torch-scatter available")
    except ImportError:
        print("β torch-scatter not available, but this may not be critical")

    # Try to safely import torchvision. AttributeError is handled as well
    # (try_import_loop also expects it from incompatible torchvision builds);
    # letting it escape here would abort the module import.
    try:
        import torchvision
        print(f"β torchvision {torchvision.__version__} loaded successfully")
    except (RuntimeError, AttributeError) as e:
        if "operator torchvision::nms does not exist" in str(e):
            print("β Compatibility issue with torchvision. Will attempt to continue anyway.")
        else:
            print(f"β torchvision error: {e}")
    except ImportError:
        print("β torchvision not available")

    if missing_deps:
        print(f"Missing dependencies: {', '.join(missing_deps)}")
    return dependencies_ok
# Run the dependency report once at import time. Missing pieces are not
# fatal at this point: start-up continues with reduced functionality.
print("Checking dependencies...")
deps_ok = check_complex_dependencies()
if deps_ok:
    print("All dependencies are available!")
else:
    print("Some dependencies are missing, but continuing anyway...",
          "The app will start with limited functionality.",
          "You can install missing dependencies manually if needed.",
          sep="\n")
# Enhanced torchvision compatibility handling
def apply_torchvision_fix():
    """
    Apply best-effort shims for torchvision/PyTorch compatibility problems.

    * Ensures ``torch.ops.torchvision`` has ``nms``, ``roi_align``,
      ``roi_pool``, ``ps_roi_align`` and ``ps_roi_pool``. Any that are missing
      are replaced by stubs that ignore their arguments and return an empty
      tensor (int64 for ``nms``). NOTE(review): callers of a stubbed op get
      empty results rather than an error.
    * If ``torchvision`` imports but has no ``extension`` attribute, installs
      a stand-in whose ``_has_ops()`` returns False.
    * If ``torchvision`` is in ``sys.modules`` without ``_meta_registrations``,
      installs an empty placeholder.

    Returns:
        bool: True when the shims were applied; False if an unexpected error
        stopped the process (the error is printed).
    """
    try:
        import types

        # Pre-emptively create torch.ops structure if needed
        if not hasattr(torch, 'ops'):
            torch.ops = types.SimpleNamespace()
        if not hasattr(torch.ops, 'torchvision'):
            torch.ops.torchvision = types.SimpleNamespace()

        # Stub implementations for missing torchvision operators
        def _empty_indices(*args, **kwargs):
            return torch.zeros(0, dtype=torch.int64)

        def _empty_tensor(*args, **kwargs):
            return torch.zeros(0)

        for op_name in ('nms', 'roi_align', 'roi_pool', 'ps_roi_align', 'ps_roi_pool'):
            if not hasattr(torch.ops.torchvision, op_name):
                stub = _empty_indices if op_name == 'nms' else _empty_tensor
                setattr(torch.ops.torchvision, op_name, stub)

        # Fix for torchvision extension issues
        try:
            import torchvision
            if not hasattr(torchvision, 'extension'):
                torchvision.extension = types.SimpleNamespace()
                torchvision.extension._has_ops = lambda: False
        except ImportError:
            pass  # torchvision not installed: nothing to patch
        except Exception as e:
            # Still best-effort, but no longer a bare `except:` that would
            # also swallow KeyboardInterrupt/SystemExit.
            print(f"Skipping torchvision extension shim: {e}")

        # Fix for torchvision meta registrations
        try:
            tv_module = sys.modules.get('torchvision')
            if tv_module is not None and not hasattr(tv_module, '_meta_registrations'):
                tv_module._meta_registrations = types.SimpleNamespace()
        except Exception as e:
            print(f"Skipping torchvision meta-registration shim: {e}")

        print("Applied comprehensive torchvision compatibility fixes")
        return True
    except Exception as e:
        print(f"Failed to apply torchvision fixes: {e}")
        return False


# Apply the shims once at import time. torchvision may already have been
# imported by the dependency check above, so this is not "before any imports".
apply_torchvision_fix()
# Custom import handling for loop module to handle dependency issues
loop = None  # set to loop.loop by try_import_loop() on success
loop_import_error = None  # failure reason recorded by try_import_loop()


def _pip_install(pip_args, timeout=300):
    """
    Run ``<python> -m pip install *pip_args`` and capture its output.

    After a successful install the import-system caches are invalidated, so
    packages installed while this interpreter is running can be found by a
    later ``import``.

    Args:
        pip_args: Extra arguments for ``pip install`` (package names/flags).
        timeout: Seconds before subprocess.run raises TimeoutExpired.

    Returns:
        subprocess.CompletedProcess with text stdout/stderr.
    """
    import importlib
    import subprocess
    result = subprocess.run([sys.executable, "-m", "pip", "install", *pip_args],
                            capture_output=True, text=True, timeout=timeout)
    if result.returncode == 0:
        importlib.invalidate_caches()
    return result


def _import_or_pip_install(module_name, pip_args):
    """
    Import *module_name*, pip-installing it once if the first import fails.

    Non-ImportError exceptions from the first import propagate to the caller.

    Returns:
        None on success, otherwise the message to store in loop_import_error.
    """
    import importlib
    try:
        importlib.import_module(module_name)
        print(f"β {module_name} imported successfully")
        return None
    except ImportError as e:
        print(f"β {module_name} import failed: {e}")
        failure = f"Critical dependency missing: {module_name} - {str(e)}"
    try:
        print(f"π Attempting to install {module_name}...")
        result = _pip_install(pip_args)
        if result.returncode != 0:
            print(f"β οΈ {module_name} installation failed: {result.stderr}")
            return failure
        print(f"β {module_name} installed successfully")
        importlib.import_module(module_name)
        print(f"β {module_name} now imported successfully")
        return None
    except Exception as install_e:
        print(f"β οΈ Could not install {module_name}: {install_e}")
        return failure


def try_import_loop():
    """
    Try to import the loop module with comprehensive error handling.

    Steps:
    1. Apply the torchvision shims and import torchvision, retrying once on
       known compatibility errors.
    2. Import nvdiffrast, pytorch3d and FashionCLIP, attempting a pip install
       for each one that is missing.
    3. Import ``loop.loop``.

    Side effects:
        Sets the module globals ``loop`` (on success) and
        ``loop_import_error`` (on failure). May run pip in a subprocess.

    Returns:
        bool: True if the processing engine was imported, else False.
    """
    global loop, loop_import_error
    try:
        # Apply torchvision fixes before importing torchvision
        apply_torchvision_fix()
        # Try to import torchvision with error handling
        try:
            import torchvision
            print(f"torchvision {torchvision.__version__} imported successfully")
        except (RuntimeError, AttributeError) as e:
            if "operator torchvision::nms does not exist" in str(e) or "extension" in str(e):
                print("Detected torchvision compatibility issue. Applying additional fixes...")
                # Re-apply fixes after the error
                apply_torchvision_fix()
                # Try importing again with sys.modules manipulation
                try:
                    if 'torchvision' in sys.modules:
                        del sys.modules['torchvision']
                    import torchvision
                    print("torchvision imported successfully after fixes")
                except Exception as e2:
                    print(f"torchvision still has issues, but continuing: {e2}")
            else:
                print(f"Other torchvision error: {e}")

        # Required modules - these are critical for production. Each one is
        # pip-installed on demand if its first import fails.
        for module_name, pip_args in (("nvdiffrast", ["nvdiffrast"]),
                                      ("pytorch3d", ["pytorch3d", "--no-deps"])):
            failure = _import_or_pip_install(module_name, pip_args)
            if failure is not None:
                loop_import_error = failure
                return False

        # FashionCLIP: install from the packages directory next to this
        # script (not relative to the current working directory).
        try:
            from packages.fashion_clip.fashion_clip.fashion_clip import FashionCLIP
            print("β FashionCLIP imported successfully")
        except ImportError as e:
            print(f"β FashionCLIP import failed: {e}")
            fashion_clip_failure = f"Critical dependency missing: FashionCLIP - {str(e)}"
            fashion_clip_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                            "packages", "fashion_clip")
            try:
                print("π Attempting to install FashionCLIP...")
                if not os.path.exists(fashion_clip_dir):
                    print("β οΈ FashionCLIP directory not found")
                    loop_import_error = fashion_clip_failure
                    return False
                result = _pip_install(["-e", fashion_clip_dir])
                if result.returncode != 0:
                    print(f"β οΈ FashionCLIP installation failed: {result.stderr}")
                    loop_import_error = fashion_clip_failure
                    return False
                print("β FashionCLIP installed successfully")
                from packages.fashion_clip.fashion_clip.fashion_clip import FashionCLIP
                print("β FashionCLIP now imported successfully")
            except Exception as install_e:
                print(f"β οΈ Could not install FashionCLIP: {install_e}")
                loop_import_error = fashion_clip_failure
                return False

        # Now try to import the loop module - this is the core processing engine
        try:
            from loop import loop as loop_func
            loop = loop_func
            print("β Successfully imported loop module - Processing engine ready!")
            return True
        except ImportError as e:
            print(f"β Loop module import failed: {e}")
            loop_import_error = f"Core processing engine failed to load: {str(e)}"
            return False
        except Exception as e:
            print(f"β Unexpected error importing loop module: {e}")
            loop_import_error = f"Unexpected error loading processing engine: {str(e)}"
            return False
    except ImportError as e:
        # e.g. torchvision itself is not installed
        error_msg = f"ImportError: {e}"
        print(error_msg)
        if "torchvision" in str(e) or "torch" in str(e):
            loop_import_error = "PyTorch/torchvision compatibility issue detected. The processing engine could not be loaded."
        else:
            loop_import_error = f"Missing dependencies: {str(e)}"
        return False
    except RuntimeError as e:
        error_msg = f"RuntimeError: {e}"
        print(error_msg)
        if "operator torchvision::nms does not exist" in str(e):
            loop_import_error = "PyTorch/torchvision version incompatibility. This is a known issue in some environments."
        else:
            loop_import_error = f"Runtime error during import: {str(e)}"
        return False
    except Exception as e:
        error_msg = f"Unexpected error: {e}"
        print(error_msg)
        loop_import_error = f"Unexpected error during import: {str(e)}"
        return False
print("Attempting to import processing engine...")
# Work out whether we run inside a Hugging Face Space before deciding
# whether to attempt dependency installation.
print("π Checking for Hugging Face Spaces environment...")
# Any single marker is enough: a Spaces env var, a 'huggingface' hostname,
# or the conventional /home/user/app checkout path.
is_hf_spaces = any((
    os.environ.get('HUGGING_FACE_SPACES', '0') == '1',
    os.environ.get('SPACE_ID') is not None,
    os.environ.get('HF_SPACE_ID') is not None,
    os.environ.get('SPACES_SDK_VERSION') is not None,
    'huggingface' in os.environ.get('HOSTNAME', '').lower(),
    os.path.exists('/home/user/app'),
))
for _env_name in ('HUGGING_FACE_SPACES', 'SPACE_ID', 'HF_SPACE_ID', 'SPACES_SDK_VERSION'):
    print(f"Environment variables: {_env_name}={os.environ.get(_env_name, 'not set')}")
print(f"Hostname: {os.environ.get('HOSTNAME', 'not set')}")
print(f"Current working directory: {os.getcwd()}")
if is_hf_spaces:
    print("π Hugging Face Spaces detected - ensuring all dependencies are installed...")
    try:
        # Check if critical dependencies are missing
        missing_deps = []
        try:
            import nvdiffrast
            print("β nvdiffrast available")
        except ImportError:
            missing_deps.append("nvdiffrast")
        try:
            import pytorch3d
            print("β pytorch3d available")
        except ImportError:
            missing_deps.append("pytorch3d")
        try:
            from packages.fashion_clip.fashion_clip.fashion_clip import FashionCLIP
            print("β FashionCLIP available")
        except ImportError:
            missing_deps.append("FashionCLIP")
        if missing_deps:
            print(f"β οΈ Missing dependencies detected: {', '.join(missing_deps)}")
            print("π Attempting to install missing dependencies...")
            # Try to run post_install script
            try:
                import importlib
                import subprocess
                # Resolve helper files against this script's directory rather
                # than the process's current working directory.
                _app_dir = os.path.dirname(os.path.abspath(__file__))
                _post_install_path = os.path.join(_app_dir, "post_install.py")
                _fashion_clip_dir = os.path.join(_app_dir, "packages", "fashion_clip")
                print("π¦ Running post-install script...")
                if os.path.exists(_post_install_path):
                    print("β post_install.py found, executing...")
                    result = subprocess.run([sys.executable, _post_install_path],
                                            capture_output=True, text=True, timeout=600,
                                            cwd=_app_dir)
                    if result.returncode == 0:
                        print("β Post-install script completed successfully")
                        print("Output:", result.stdout)
                        # Let the later engine import see freshly installed packages.
                        importlib.invalidate_caches()
                    else:
                        print(f"β οΈ Post-install script failed with return code {result.returncode}")
                        print(f"Error output: {result.stderr}")
                        print(f"Standard output: {result.stdout}")
                else:
                    print("β οΈ post_install.py not found, attempting manual installation...")
                    # Manual installation of critical dependencies
                    try:
                        for _dep_name, _pip_args in (("nvdiffrast", ["nvdiffrast"]),
                                                     ("pytorch3d", ["pytorch3d", "--no-deps"])):
                            print(f"π¦ Installing {_dep_name}...")
                            result = subprocess.run([sys.executable, "-m", "pip", "install", *_pip_args],
                                                    capture_output=True, text=True, timeout=300)
                            if result.returncode == 0:
                                print(f"β {_dep_name} installed successfully")
                                continue
                            print(f"β οΈ {_dep_name} installation failed: {result.stderr}")
                            # Retry once with pip's cache disabled
                            print(f"π Trying alternative {_dep_name} installation...")
                            result = subprocess.run([sys.executable, "-m", "pip", "install", *_pip_args, "--no-cache-dir"],
                                                    capture_output=True, text=True, timeout=300)
                            if result.returncode == 0:
                                print(f"β {_dep_name} installed successfully (alternative method)")
                            else:
                                print(f"β οΈ Alternative {_dep_name} installation also failed: {result.stderr}")
                        print("π¦ Installing FashionCLIP...")
                        if os.path.exists(_fashion_clip_dir):
                            result = subprocess.run([sys.executable, "-m", "pip", "install", "-e", _fashion_clip_dir],
                                                    capture_output=True, text=True, timeout=300)
                            if result.returncode == 0:
                                print("β FashionCLIP installed successfully")
                            else:
                                print(f"β οΈ FashionCLIP installation failed: {result.stderr}")
                        else:
                            print("β οΈ FashionCLIP directory not found")
                        print("β Manual installation completed")
                        # Refresh import finder caches; without this the
                        # re-check below can miss packages installed just now.
                        importlib.invalidate_caches()
                        # Re-check dependencies after installation
                        print("π Re-checking dependencies after installation...")
                        try:
                            import nvdiffrast
                            print("β nvdiffrast now available")
                        except ImportError:
                            print("β οΈ nvdiffrast still not available")
                        try:
                            import pytorch3d
                            print("β pytorch3d now available")
                        except ImportError:
                            print("β οΈ pytorch3d still not available")
                        try:
                            from packages.fashion_clip.fashion_clip.fashion_clip import FashionCLIP
                            print("β FashionCLIP now available")
                        except ImportError:
                            print("β οΈ FashionCLIP still not available")
                    except Exception as e:
                        print(f"β οΈ Manual installation failed: {e}")
            except Exception as e:
                print(f"β οΈ Could not run post-install script: {e}")
        else:
            print("β All critical dependencies are available")
    except Exception as e:
        print(f"β οΈ Error checking dependencies: {e}")
else:
    print("π Local development environment detected")
# Import the processing engine now that dependency checks/installs have run.
print("π Attempting to import processing engine after dependency checks...")
import_success = try_import_loop()
if not import_success:
    print(f"β Processing engine failed to load: {loop_import_error}")
    print("β CRITICAL ERROR: Cannot start in production mode without core dependencies.")
    print("Please ensure all required dependencies are installed:",
          " - nvdiffrast",
          " - pytorch3d",
          " - FashionCLIP",
          " - torchvision (with compatibility fixes)",
          " - All other required packages",
          sep="\n")
    # Fail fast on Hugging Face Spaces; keep running locally for debugging.
    if not is_hf_spaces:
        print("β οΈ Development mode - continuing with limited functionality")
    else:
        print("π¨ Production environment detected - exiting due to missing dependencies")
        print("π‘ Try running: python post_install.py")
        print("π‘ Or check the logs above for installation errors")
        print("π‘ You may need to restart the application after dependencies are installed")
        sys.exit(1)
else:
    print("β Processing engine loaded successfully")
    print("π― Production mode: All critical dependencies are available!")
# Ensure NeuralJacobianFields is properly configured: force
# USE_TORCH_SPARSE = False in NeuralJacobianFields/PoissonSystem.py next to
# this script.
# NOTE(review): this runs after try_import_loop(); if the engine import has
# already loaded PoissonSystem, the edit presumably only takes effect on the
# next start - confirm.


def _write_text_atomically(path, text):
    """
    Replace the contents of *path* with *text* with no half-written window.

    The text is written to a temporary file in the same directory. That file
    receives the original's permission bits and then atomically replaces it
    via os.replace. The temporary file is removed if anything fails.
    """
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(path), suffix=".tmp")
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as tmp:
            tmp.write(text)
        shutil.copymode(path, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


try:
    # Check if PoissonSystem.py needs to be modified to disable torch-sparse
    poisson_system_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                       "NeuralJacobianFields", "PoissonSystem.py")
    if os.path.exists(poisson_system_path):
        with open(poisson_system_path, 'r', encoding='utf-8') as f:
            content = f.read()
        if "USE_TORCH_SPARSE = True" in content:
            print("Disabling torch-sparse in PoissonSystem.py")
            content = content.replace("USE_TORCH_SPARSE = True", "USE_TORCH_SPARSE = False")
            # Atomic replace: an interrupted write can no longer leave a
            # truncated/corrupted PoissonSystem.py behind.
            _write_text_atomically(poisson_system_path, content)
            print("Successfully disabled torch-sparse in PoissonSystem.py")
except Exception as e:
    print(f"Warning: Could not check/modify NeuralJacobianFields configuration: {e}")
    # Continue execution, as this is not fatal
# Default run settings. process_garment() copies this dict and overrides
# per-request entries (paths, epochs, learning rate, loss weights, ...).
DEFAULT_CONFIG = dict(
    output_path='./outputs',
    gpu=0,
    seed=99,
    clip_model='ViT-B/32',
    consistency_clip_model='ViT-B/32',
    consistency_vit_stride=8,
    consistency_vit_layer=11,
    mesh=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'meshes', 'longsleeve.obj'),
    target_mesh=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'meshes_target', 'jacket_sdf_new.obj'),
    retriangulate=0,
    bsdf='diffuse',
    lr=0.0025,
    epochs=1800,
    clip_weight=2.5,
    delta_clip_weight=5,
    vgg_weight=0.0,
    face_weight=0,
    regularize_jacobians_weight=0.15,
    consistency_loss_weight=0,
    consistency_elev_filter=30,
    consistency_azim_filter=20,
    batch_size=24,
    train_res=512,
    resize_method='cubic',
    fov_min=30.0,
    fov_max=90.0,
    dist_min=2.5,
    dist_max=3.5,
    light_power=5.0,
    elev_alpha=1.0,
    elev_beta=5.0,
    elev_max=60.0,
    azim_alpha=1.0,
    azim_beta=5.0,
    azim_min=0.0,
    azim_max=360.0,
    aug_loc=1,
    aug_light=1,
    aug_bkg=0,
    adapt_dist=1,
    log_interval=5,
    log_interval_im=150,
    log_elev=0,
    log_fov=60.0,
    log_dist=3.0,
    log_res=512,
    log_light_power=3.0,
)
def process_garment(input_type, text_prompt, base_text_prompt, mesh_target_image, source_mesh_type, custom_mesh, epochs, learning_rate, clip_weight, delta_clip_weight, progress=gr.Progress()): | |
""" | |
Main function to process garment generation | |
Args: | |
input_type: Either "Text" or "Image to Mesh" to determine the processing mode | |
text_prompt: Text description of target garment (for text mode) | |
base_text_prompt: Text description of base garment (for text mode) | |
mesh_target_image: Image for generating a 3D mesh (for image to mesh mode) | |
source_mesh_type: Type of source mesh to use as starting point (for image to mesh mode) | |
custom_mesh: Optional custom source mesh file (.obj) | |
epochs: Number of optimization epochs | |
learning_rate: Optimization learning rate | |
clip_weight: Weight for CLIP loss | |
delta_clip_weight: Weight for delta CLIP loss | |
progress: Gradio progress tracking object | |
""" | |
try: | |
# Create a temporary output directory | |
with tempfile.TemporaryDirectory() as temp_dir: | |
# Update configuration | |
config = DEFAULT_CONFIG.copy() | |
# Set up input parameters based on mode | |
if input_type == "Image to Mesh": | |
if mesh_target_image is None: | |
return "Error: Please upload an image for Image to Mesh mode." | |
# Image-to-Mesh processing | |
progress(0.05, desc="Preparing mesh generation from image...") | |
# Save target image to temp directory | |
target_mesh_image_path = os.path.join(temp_dir, "target_mesh_image.jpg") | |
try: | |
if isinstance(mesh_target_image, str): | |
shutil.copy(mesh_target_image, target_mesh_image_path) | |
elif isinstance(mesh_target_image, np.ndarray): | |
# Ensure the array is in the correct format | |
if len(mesh_target_image.shape) == 3: | |
if mesh_target_image.shape[2] == 4: # RGBA | |
mesh_target_image = mesh_target_image[:,:,:3] # Convert to RGB | |
img = Image.fromarray(mesh_target_image.astype(np.uint8)) | |
img.save(target_mesh_image_path) | |
else: | |
return "Error: Invalid image format. Please upload a valid RGB image." | |
elif hasattr(mesh_target_image, 'save'): | |
mesh_target_image.save(target_mesh_image_path) | |
else: | |
print(f"Unsupported image type: {type(mesh_target_image)}") | |
return "Error: Could not process the uploaded image. Please try a different image format." | |
print(f"Target mesh image saved to {target_mesh_image_path}") | |
# Set mesh paths based on selected source mesh type | |
# Map display names to actual file names | |
mesh_mapping = { | |
"tshirt": "tshirt", | |
"longsleeve": "longsleeve", | |
"tanktop": "tanktop", | |
"poncho": "poncho", | |
"dress_shortsleeve": "dress_shortsleeve" | |
} | |
mesh_file = mesh_mapping.get(source_mesh_type, "tshirt") | |
# Use absolute paths for mesh files | |
current_dir = os.path.dirname(os.path.abspath(__file__)) | |
source_mesh_file = os.path.join(current_dir, "meshes", f"{mesh_file}.obj") | |
# Check if the mesh file exists | |
if not os.path.exists(source_mesh_file): | |
return f"Error: Mesh file {source_mesh_file} not found. Please check if the mesh files are available." | |
print(f"Using source mesh: {source_mesh_file}") | |
# Configure for image-to-mesh processing | |
config.update({ | |
'mesh': source_mesh_file, | |
'image_prompt': target_mesh_image_path, | |
'base_image_prompt': target_mesh_image_path, # Use same image as base | |
'use_target_mesh': True, | |
'fashion_image': True, | |
'fashion_text': False, | |
}) | |
except Exception as e: | |
print(f"Error processing image: {e}") | |
return f"Error: Failed to process the uploaded image: {str(e)}" | |
else: | |
# Text-based processing | |
if not text_prompt or len(text_prompt.strip()) == 0: | |
return "Error: Text prompt is required for text-based generation." | |
if not base_text_prompt or len(base_text_prompt.strip()) == 0: | |
base_text_prompt = "simple t-shirt" # Default base prompt | |
config.update({ | |
'text_prompt': text_prompt, | |
'base_text_prompt': base_text_prompt, | |
'fashion_image': False, | |
'fashion_text': True | |
}) | |
# Handle custom mesh if provided | |
if custom_mesh is not None: | |
custom_mesh_path = os.path.join(temp_dir, "custom_mesh.obj") | |
shutil.copy(custom_mesh, custom_mesh_path) | |
config['mesh'] = custom_mesh_path | |
# Update optimization parameters | |
config.update({ | |
'output_path': temp_dir, | |
'epochs': int(epochs), | |
'lr': float(learning_rate), | |
'clip_weight': float(clip_weight), | |
'delta_clip_weight': float(delta_clip_weight), | |
'gpu': 0 # Use first GPU | |
}) | |
# Set random seeds | |
random.seed(config['seed']) | |
os.environ['PYTHONHASHSEED'] = str(config['seed']) | |
np.random.seed(config['seed']) | |
torch.manual_seed(config['seed']) | |
torch.cuda.manual_seed(config['seed']) | |
torch.backends.cudnn.deterministic = True | |
progress(0.1, desc="Initializing...") | |
# Print configuration for debugging | |
print("Starting processing with configuration:") | |
print(f"Mode: {'Image' if config.get('fashion_image', False) else 'Text'}") | |
if config.get('fashion_image', False): | |
print(f"Target image: {config['image_prompt']}") | |
print(f"Base image: {config['base_image_prompt']}") | |
else: | |
print(f"Target text: {config['text_prompt']}") | |
print(f"Base text: {config['base_text_prompt']}") | |
# Run the main processing loop | |
progress(0.2, desc="Running garment generation...") | |
try: | |
# Check if loop is available (should always be available in production) | |
if loop is None: | |
error_message = "Error: Processing engine not available. Please check dependencies." | |
print(error_message) | |
return error_message | |
# Run the loop with error handling | |
try: | |
print("π Starting garment generation with real processing engine...") | |
print(f"Configuration: {config}") | |
# Validate mesh files before processing | |
if 'mesh' in config and config['mesh']: | |
mesh_path = config['mesh'] | |
if not os.path.exists(mesh_path): | |
error_message = f"Error: Source mesh file not found: {mesh_path}" | |
print(error_message) | |
return error_message | |
# Check if mesh file is valid | |
try: | |
import pymeshlab | |
ms = pymeshlab.MeshSet() | |
ms.load_new_mesh(mesh_path) | |
if ms.current_mesh().vertex_number() == 0: | |
error_message = f"Error: Source mesh file has no vertices: {mesh_path}" | |
print(error_message) | |
return error_message | |
print(f"β Source mesh validated: {ms.current_mesh().vertex_number()} vertices, {ms.current_mesh().face_number()} faces") | |
except Exception as mesh_e: | |
print(f"Warning: Could not validate mesh file: {mesh_e}") | |
loop(config) | |
print("β Garment generation completed successfully!") | |
except ValueError as ve: | |
print(f"β Validation error during garment generation: {ve}") | |
if "no vertices" in str(ve).lower() or "no faces" in str(ve).lower(): | |
error_message = f"Error: Invalid mesh data detected. The source mesh appears to be corrupted or empty. Please try a different mesh file." | |
elif "jacobian" in str(ve).lower(): | |
error_message = f"Error: Jacobian computation failed. This may indicate an issue with the mesh structure or processing pipeline." | |
elif "index" in str(ve).lower() and "bounds" in str(ve).lower(): | |
error_message = f"Error: Mesh processing failed due to invalid data structure. This may indicate corrupted mesh files or processing errors." | |
else: | |
error_message = f"Error during processing: {str(ve)}" | |
return error_message | |
except FileNotFoundError as fe: | |
print(f"β File not found error during garment generation: {fe}") | |
if "mesh" in str(fe).lower(): | |
error_message = f"Error: Required mesh file not found during processing. This may indicate an issue with the mesh loading pipeline." | |
elif "mtl" in str(fe).lower(): | |
error_message = f"Error: Material file not found. This may indicate an issue with the mesh file structure." | |
else: | |
error_message = f"Error: Required file not found during processing: {str(fe)}" | |
return error_message | |
except Exception as e: | |
print(f"β Error during garment generation: {e}") | |
import traceback | |
traceback.print_exc() | |
# Provide more specific error messages based on error type | |
if "nvdiffrast" in str(e).lower(): | |
error_message = "Error: Rendering engine (nvdiffrast) failed. This may be due to OpenGL/EGL compatibility issues." | |
elif "clip" in str(e).lower(): | |
error_message = "Error: CLIP model failed to load or process. This may be due to model availability or compatibility issues." | |
elif "cuda" in str(e).lower() or "gpu" in str(e).lower(): | |
error_message = "Error: GPU/CUDA processing failed. This may be due to hardware compatibility or driver issues." | |
elif "memory" in str(e).lower(): | |
error_message = "Error: Insufficient memory during processing. Try reducing the number of epochs or using a smaller mesh." | |
else: | |
error_message = f"Error during processing: {str(e)}" | |
return error_message | |
except RuntimeError as e: | |
print(f"Runtime error during processing: {e}") | |
if "operator torchvision::nms does not exist" in str(e): | |
error_message = "Error: PyTorch/torchvision version incompatibility detected. This is a known issue in some environments." | |
print(error_message) | |
return error_message | |
else: | |
error_message = f"Runtime error during processing: {str(e)}" | |
print(error_message) | |
return error_message | |
except Exception as e: | |
print(f"Error during processing: {e}") | |
error_message = f"Error during processing: {str(e)}" | |
print(error_message) | |
return error_message | |
progress(0.9, desc="Processing complete, preparing output...") | |
# Look for output files and ensure both OBJ and GLB formats are available | |
obj_files = [] | |
glb_files = [] | |
image_files = [] | |
print("Searching for output files and ensuring GLB conversion...") | |
# First check for mesh files in mesh_final directory (priority) | |
mesh_final_dir = Path(temp_dir) / "mesh_final" | |
if mesh_final_dir.exists(): | |
print(f"Found mesh_final directory at {mesh_final_dir}") | |
# Ensure both OBJ and GLB formats are available | |
obj_files, glb_files = ensure_mesh_formats(mesh_final_dir) | |
print(f"Found {len(obj_files)} OBJ files and {len(glb_files)} GLB files in mesh_final") | |
else: | |
print("mesh_final directory not found") | |
# Check other mesh directories | |
for mesh_dir in Path(temp_dir).glob("mesh_*"): | |
if mesh_dir.is_dir() and mesh_dir.name != 'mesh_final': | |
print(f"Checking directory: {mesh_dir}") | |
dir_obj_files, dir_glb_files = ensure_mesh_formats(mesh_dir) | |
obj_files.extend(dir_obj_files) | |
glb_files.extend(dir_glb_files) | |
# Collect image files for visualization | |
for file_path in Path(temp_dir).rglob("*"): | |
if file_path.is_file() and file_path.suffix.lower() in ['.png', '.jpg', '.jpeg', '.gif', '.mp4']: | |
image_files.append(str(file_path)) | |
print(f"Found {len(glb_files)} GLB files, {len(obj_files)} OBJ files, and {len(image_files)} image files") | |
# Prioritize output: GLB, OBJ, then images | |
if glb_files: | |
print(f"Returning GLB file: {glb_files[0]}") | |
return glb_files[0] # Return first GLB file (best for web viewing) | |
elif obj_files: | |
print(f"Returning OBJ file: {obj_files[0]}") | |
return obj_files[0] # Return first OBJ file | |
elif image_files: | |
print(f"Returning image file: {image_files[0]}") | |
return image_files[0] # Return an image if no mesh was found | |
else: | |
print("No output files found") | |
return None | |
except Exception as e: | |
import traceback | |
error_details = traceback.format_exc() | |
print(f"Error during processing: {str(e)}") | |
print(f"Error details: {error_details}") | |
# Return None instead of an error string to avoid file not found errors with Gradio | |
return None | |
def create_combined_mesh_output(output_dir):
    """
    Choose which generated mesh file(s) to present for ``output_dir``.

    ``ensure_mesh_formats`` is asked for the OBJ and GLB files in the directory.
    A GLB is preferred as the primary file (it is the format shown for web
    viewing); when both formats exist the OBJ is returned as the secondary file.

    Args:
        output_dir: Directory containing mesh files.

    Returns:
        tuple: ``(primary_file, secondary_file, status_message)``; either file
        entry may be ``None``.
    """
    obj_candidates, glb_candidates = ensure_mesh_formats(output_dir)
    has_glb = bool(glb_candidates)
    has_obj = bool(obj_candidates)

    # Nothing usable was produced.
    if not (has_glb or has_obj):
        return None, None, "β No mesh files were generated. Please check the processing logs."

    # Exactly one format exists: present it alone.
    if not has_obj:
        return glb_candidates[0], None, "π Success! GLB file generated and ready for download."
    if not has_glb:
        return obj_candidates[0], None, "π Success! OBJ file generated and ready for download."

    # Both formats exist: GLB first, OBJ as the alternative download.
    return (
        glb_candidates[0],
        obj_candidates[0],
        "π Success! Both GLB and OBJ files generated. GLB file is displayed (better for web viewing), OBJ file is also available.",
    )
def create_interface():
    """
    Build the Garment3DGen Gradio interface.

    Layout:
      * left column: generation-mode radio, the Text inputs (target/base
        prompts), the Image-to-Mesh inputs (garment image + base mesh type),
        an optional custom ``.obj`` source mesh, the optimisation sliders and
        the "Generate 3D Garment" button;
      * right column: the primary output file, an initially hidden
        alternative-format output, usage tips and status messages.

    Events:
      * changing the mode radio toggles which input group is visible;
      * the generate button runs ``process_garment`` (through
        ``process_with_feedback``) and fills both file outputs and the status.

    Returns:
        gr.Blocks: the constructed interface (not launched).
    """
    # Status text shared by every branch that yields both mesh formats.
    both_formats_msg = "π Success! Both GLB and OBJ files generated. GLB file is displayed (better for web viewing), OBJ file is also available."

    with gr.Blocks(title="Garment3DGen - 3D Garment Stylization") as interface:
        # gr.Markdown dedents its value, so the indentation here is not rendered.
        gr.Markdown("""
        # Garment3DGen: 3D Garment Stylization and Texture Generation

        This tool allows you to stylize 3D garments using text prompts or images. Generate a new 3D garment mesh (.obj/.glb)
        that can be used for virtual try-on applications.

        ## How to use:

        1. Choose **Text** or **Image to Mesh** input mode using the radio button below
        2. For **Text** mode: Enter descriptions of your target and base garment styles
        3. For **Image to Mesh** mode: Upload an image to generate a 3D mesh directly and select a base mesh type
        4. Click "Generate 3D Garment" to create your 3D mesh file
        5. **GLB files** are automatically generated for better web viewing and virtual try-on compatibility
        """)

        with gr.Row():
            with gr.Column(scale=1):
                # Input type selector
                input_type = gr.Radio(
                    choices=["Text", "Image to Mesh"],
                    value="Text",
                    label="Generation Method",
                    interactive=True
                )

                # Text inputs (visible by default)
                with gr.Group(visible=True) as text_group:
                    text_prompt = gr.Textbox(
                        label="Target Text Prompt",
                        placeholder="e.g., leather jacket with studs",
                        value="leather jacket with studs"
                    )
                    base_text_prompt = gr.Textbox(
                        label="Base Text Prompt",
                        placeholder="e.g., simple t-shirt",
                        value="simple t-shirt"
                    )

                # Image to Mesh inputs (hidden by default)
                with gr.Group(visible=False) as image_to_mesh_group:
                    gr.Markdown("### πΈ Upload Garment Image")
                    mesh_target_image = gr.Image(
                        label="Target Garment Image for Mesh Generation",
                        sources=["upload", "clipboard", "webcam"],
                        type="numpy",
                        interactive=True,
                        height=300,
                        show_label=True
                    )
                    gr.Markdown("*Upload an image of the garment to convert directly to a 3D mesh*")
                    gr.Markdown("### π― Select Base Mesh Type")
                    source_mesh_type = gr.Dropdown(
                        label="Source Mesh Type",
                        choices=["tshirt", "longsleeve", "tanktop", "poncho", "dress_shortsleeve"],
                        value="tshirt",
                        interactive=True
                    )
                    gr.Markdown("*Select the type of base garment mesh to use as a starting point*")

                # Custom mesh
                custom_mesh = gr.File(
                    label="Custom Source Mesh (Optional)",
                    file_types=[".obj"]
                )

                # Simple parameters
                epochs = gr.Slider(
                    minimum=100,
                    maximum=3000,
                    value=1800,
                    step=100,
                    label="Number of Epochs"
                )
                learning_rate = gr.Slider(
                    minimum=0.0001,
                    maximum=0.01,
                    value=0.0025,
                    step=0.0001,
                    label="Learning Rate"
                )
                clip_weight = gr.Slider(
                    minimum=0.1,
                    maximum=10.0,
                    value=2.5,
                    step=0.1,
                    label="CLIP Weight"
                )
                delta_clip_weight = gr.Slider(
                    minimum=0.1,
                    maximum=20.0,
                    value=5.0,
                    step=0.1,
                    label="Delta CLIP Weight"
                )

                generate_btn = gr.Button("Generate 3D Garment")

            with gr.Column():
                # Primary output (GLB preferred)
                output = gr.File(
                    label="Generated 3D Garment (GLB/OBJ)",
                    file_types=[".obj", ".glb", ".png", ".jpg"],
                    file_count="single"
                )
                # Secondary output (the other format); shown only when a file is available
                secondary_output = gr.File(
                    label="Alternative Format (OBJ/GLB)",
                    file_types=[".obj", ".glb"],
                    file_count="single",
                    visible=False
                )

                gr.Markdown("""
                ## Tips:

                - For text mode: Be specific in your descriptions (e.g., "red leather jacket with zippers")
                - For image to mesh mode: Use clear, front-facing garment images to generate a 3D mesh directly
                - Choose the appropriate base mesh type that matches your target garment
                - Higher epochs = better quality but longer processing time
                - **GLB files** are automatically generated for better web viewing and virtual try-on compatibility
                - **OBJ files** are also available for traditional 3D software compatibility
                - Output files can be downloaded by clicking on them

                Processing may take several minutes.
                """)

                # Status output for errors and messages; the initial text reflects
                # whether the processing engine (`loop`) was imported.
                if loop is not None:
                    engine_status = "β Processing engine loaded successfully - Production Ready!"
                    status_msg = "π― Ready to generate garments! Select an input method and click 'Generate 3D Garment'."
                else:
                    engine_status = f"β Processing engine unavailable: {loop_import_error or 'Unknown error'}"
                    status_msg = "β CRITICAL ERROR: Processing engine failed to load. Please check that all dependencies are properly installed."
                gr.Markdown(f"**System Status:** {engine_status}")
                status_output = gr.Markdown(status_msg)

        def update_mode(mode):
            """Show the input group matching ``mode`` and update the status text."""
            print(f"Mode changed to: {mode}")
            text_visibility = mode == "Text"
            image_to_mesh_visibility = mode == "Image to Mesh"
            status_msg = f"Mode changed to {mode}. "
            if text_visibility:
                status_msg += "Enter garment descriptions and click Generate."
            else:
                status_msg += "Upload a garment image and select mesh type, then click Generate."
            print(f"Text visibility: {text_visibility}, Image to Mesh visibility: {image_to_mesh_visibility}")
            print(f"Returning updates: text_group={text_visibility}, image_to_mesh_group={image_to_mesh_visibility}")
            # gr.Group.update() no longer exists in Gradio 4 (this file relies on
            # the 4.x `sources=` argument of gr.Image); gr.update() is the
            # supported way to change a component's properties from an event.
            return (
                gr.update(visible=text_visibility),
                gr.update(visible=image_to_mesh_visibility),
                status_msg,
            )

        def resolve_result(result):
            """
            Map a ``process_garment`` return value to ``(primary, secondary, status)``.

            ``result`` is None (nothing produced), a path to an output file, or
            an error-message string. For a mesh path, the other mesh format is
            looked up next to it (and an OBJ is converted to GLB if missing) so
            both downloads can be offered.
            """
            if result is None:
                return None, None, "Processing completed but no output files were generated. Please check the logs for more details."

            if isinstance(result, str) and os.path.exists(result):
                result_path = Path(result)
                suffix = result_path.suffix.lower()
                if suffix == '.glb':
                    obj_file = result_path.with_suffix('.obj')
                    if obj_file.exists():
                        return result, str(obj_file), both_formats_msg
                    return result, None, "π Success! GLB file generated and ready for download."
                if suffix == '.obj':
                    glb_file = result_path.with_suffix('.glb')
                    if glb_file.exists():
                        return str(glb_file), result, both_formats_msg
                    # No GLB yet: try to convert so the web-friendly format is primary.
                    glb_path = convert_obj_to_glb(result)
                    if glb_path:
                        return glb_path, result, both_formats_msg
                    return result, None, "π Success! OBJ file generated and ready for download."
                # Some other file type (e.g. a rendered image)
                return result, None, "π Processing completed successfully! Download your file below."

            if isinstance(result, str):
                # process_garment reports failures as "Error: ...", "Error during
                # processing: ..." or "Runtime error during processing: ..." --
                # all of them are errors, not only the "Error:" ones.
                if result.startswith(("Error", "Runtime error")):
                    return None, None, result
                return None, None, f"Unexpected result: {result}"

            return result, None, "π Processing completed successfully! Download your 3D garment file below."

        def process_with_feedback(*args):
            """
            Run ``process_garment`` and return updates for
            ``(output, secondary_output, status_output)``.

            Any exception is logged and turned into a status message rather than
            propagated to Gradio.
            """
            try:
                if loop is None:
                    primary, secondary, status = None, None, "β ERROR: Processing engine not available. Please check that all dependencies are properly installed."
                else:
                    primary, secondary, status = resolve_result(process_garment(*args))
            except Exception as e:
                import traceback
                print(f"Error in interface: {str(e)}")
                print(traceback.format_exc())
                primary, secondary, status = None, None, f"β Error: {str(e)}"
            # The secondary output starts hidden; show it exactly when there is an
            # alternative-format file to offer.
            return primary, gr.update(value=secondary, visible=secondary is not None), status

        def update_secondary_visibility(primary_file):
            """
            Hide the alternative-format output when the primary output is cleared.

            Visibility for a new result is set by ``process_with_feedback``.
            Looking for a sibling ``.obj``/``.glb`` next to ``primary_file`` does
            not work: Gradio serves each output file from its own cache
            directory, so the sibling is never there and the secondary output
            would be hidden again right after being shown.
            """
            if not primary_file:
                return gr.update(visible=False)
            return gr.update()

        # Toggle visibility based on input mode
        input_type.change(
            fn=update_mode,
            inputs=[input_type],
            outputs=[text_group, image_to_mesh_group, status_output]
        )

        # Connect the button to the processing function with error handling and dual output
        generate_btn.click(
            fn=process_with_feedback,
            inputs=[
                input_type,
                text_prompt,
                base_text_prompt,
                mesh_target_image,
                source_mesh_type,
                custom_mesh,
                epochs,
                learning_rate,
                clip_weight,
                delta_clip_weight
            ],
            outputs=[output, secondary_output, status_output]
        )

        # Keep the secondary output hidden when the primary output is emptied
        output.change(
            fn=update_secondary_visibility,
            inputs=[output],
            outputs=[secondary_output]
        )

    return interface
if __name__ == "__main__":
    print("Starting Garment3DGen application...")

    # Re-apply the torchvision compatibility patch right before launch.
    # A failure here is reported but never stops the launch.
    try:
        apply_torchvision_fix()
        print("Final torchvision compatibility check completed")
    except Exception as fix_error:
        print(f"Warning: Could not apply final torchvision fixes: {fix_error}")

    # Build and serve the UI; any failure is reported together with a
    # categorised hint banner chosen from the exception text.
    try:
        interface = create_interface()
        print("Gradio interface created successfully")
        interface.launch(
            share=False,
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
            quiet=False,
            debug=True,
        )
    except Exception as launch_error:
        import traceback

        print(f"Error launching interface: {launch_error}")
        print("Full error traceback:")
        print(traceback.format_exc())

        reason = str(launch_error)
        if "torchvision" in reason or "operator" in reason:
            hint_lines = (
                "CRITICAL ERROR: PyTorch/torchvision compatibility issue detected.",
                "This is a known issue in some environments.",
                "The error occurred during interface launch.",
            )
        elif "loop" in reason or "dependencies" in reason:
            hint_lines = (
                "DEPENDENCY ERROR: Required modules could not be loaded.",
                "Check that all dependencies are properly installed.",
            )
        else:
            hint_lines = (
                "UNKNOWN ERROR: An unexpected error occurred.",
                "Please check the logs above for more details.",
            )

        separator = "=" * 80
        print("\n" + separator)
        for hint in hint_lines:
            print(hint)
        print(separator + "\n")