# 3D-Viewer-AI / app.py
# Author: ruslanmv — commit 3f92458 ("Update app.py"), 13.7 kB
import gradio as gr
import plotly.graph_objs as go
import trimesh
import numpy as np
from PIL import Image, ImageDraw
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
import io
import matplotlib.pyplot as plt
#import pyrender
#import scipy
import csv
import sys
import os
# Load the Stable Diffusion pipelines: text-to-image (used by
# generate_clothing_image) and inpainting. Both are moved to the GPU when
# CUDA is available, otherwise to the CPU.
# NOTE(review): pipeline_inpaint is not referenced anywhere else in the
# visible source — confirm it is still needed before paying its load cost.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)

# Request float16 weights only on CUDA; on CPU no torch_dtype is passed.
_inpaint_options = {"torch_dtype": torch.float16} if device == "cuda" else {}
pipeline_inpaint = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    **_inpaint_options,
).to(device)
# Example assets are resolved relative to the process's working directory.
CURRENT_DIR = os.getcwd()

DEFAULT_OBJ_FILE = os.path.join(CURRENT_DIR, "female.obj")
DEFAULT_GLB_FILE = os.path.join(CURRENT_DIR, "vroid_girl1.glb")
DEFAULT_VRM_FILE = os.path.join(CURRENT_DIR, "fischl.vrm")
DEFAULT_VRM_FILE2 = os.path.join(CURRENT_DIR, "woman.vrm")
DEFAULT_VRM_FILE3 = os.path.join(CURRENT_DIR, "mona.vrm")
DEFAULT_TEXTURE = os.path.join(CURRENT_DIR, "future.png")
DEFAULT_TEXTURE2 = os.path.join(CURRENT_DIR, "woman1.jpeg")
DEFAULT_TEXTURE3 = os.path.join(CURRENT_DIR, "woman2.jpeg")

# (model path, texture path) pairs; None means the example has no texture.
example_files = [
    [DEFAULT_VRM_FILE, DEFAULT_TEXTURE],
    [DEFAULT_OBJ_FILE, None],
    [DEFAULT_GLB_FILE, None],
    [DEFAULT_VRM_FILE2, DEFAULT_TEXTURE2],
    [DEFAULT_VRM_FILE3, DEFAULT_TEXTURE3]
]


def _warn_about_missing_examples(pairs):
    """Print a warning for every non-None path in *pairs* that is absent on disk."""
    for path in (candidate for pair in pairs for candidate in pair):
        if path and not os.path.exists(path):
            print(f"Warning: Example file {path} does not exist!")


# Missing assets are only reported; the app keeps starting up.
_warn_about_missing_examples(example_files)
def generate_default_uv(mesh, quality='medium'):
    """
    Generate fallback UV coordinates for a mesh that has no UV mapping.

    Projection used for each *quality*:
      * 'low'    - planar projection onto the XY plane, normalised by the
                   mesh's X/Y bounding-box extent.
      * 'medium' - cylindrical projection around the Z axis: U is the azimuth,
                   V is the Z height normalised to [0, 1].
      * 'high'   - spherical projection about the origin: U is the azimuth,
                   V is the polar angle divided by pi.

    Degenerate geometry (zero X/Y/Z extent, vertices exactly at the origin)
    is handled so the result never contains NaN or inf.

    Args:
        mesh: object exposing ``vertices`` (N x 3) and, for 'low', ``bounds``
            (2 x 3: min corner, max corner).
        quality: one of 'low', 'medium', 'high'.

    Returns:
        (N, 2) float array of UV coordinates.

    Raises:
        ValueError: if *quality* is not a supported value.
    """
    if quality not in ('low', 'medium', 'high'):
        raise ValueError("Invalid quality parameter. Choose from 'low', 'medium', or 'high'.")

    vertices = np.asarray(mesh.vertices, dtype=float).reshape(-1, 3)
    uv_coords = np.zeros((len(vertices), 2))
    if len(vertices) == 0:
        return uv_coords

    x, y, z = vertices[:, 0], vertices[:, 1], vertices[:, 2]

    if quality == 'low':
        bounds = np.asarray(mesh.bounds, dtype=float)
        extent = bounds[1, :2] - bounds[0, :2]
        # A mesh that is flat along X or Y would otherwise divide 0 by 0.
        extent = np.where(extent > 0, extent, 1.0)
        uv_coords[:, 0] = (x - bounds[0, 0]) / extent[0]
        uv_coords[:, 1] = (y - bounds[0, 1]) / extent[1]
        return uv_coords

    # Cylindrical and spherical projections share the azimuth as U.
    uv_coords[:, 0] = np.arctan2(y, x) / (2 * np.pi) + 0.5
    if quality == 'medium':
        z_min = z.min()
        height_range = z.max() - z_min
        uv_coords[:, 1] = (z - z_min) / (height_range if height_range > 0 else 1.0)
    else:
        radius = np.sqrt(np.sum(vertices ** 2, axis=1))
        # Vertices at the origin have no defined polar angle; avoid 0/0.
        safe_radius = np.where(radius > 0, radius, 1.0)
        # Clip guards arccos against rounding pushing |z / r| just past 1.
        uv_coords[:, 1] = np.arccos(np.clip(z / safe_radius, -1.0, 1.0)) / np.pi
    return uv_coords
def apply_texture(mesh, texture_image, uv_scale, uv_quality='medium'):
    """
    Sample one colour per face of *mesh* from *texture_image*.

    Each face gets the mean texel colour under its three vertices' UV
    coordinates. If the mesh has no UV mapping, one is generated with
    ``generate_default_uv(mesh, quality=uv_quality)``.

    Args:
        mesh: trimesh-like object exposing ``faces`` and ``visual`` (plus
            ``vertices``/``bounds`` when UVs have to be generated).
        texture_image: PIL image (anything with ``convert("RGB")`` that
            NumPy can turn into an H x W x 3 array).
        uv_scale: multiplier applied to the UVs before they are clamped to
            [0, 1] (values outside the texture stick to its edge; no tiling).
        uv_quality: 'low', 'medium' or 'high'; only used when generating UVs.

    Returns:
        (n_faces, 3) float array of RGB colours in [0, 1]. Faces touching a
        non-finite UV get neutral grey (0.5, 0.5, 0.5).

    Raises:
        ValueError: if UVs must be generated and *uv_quality* is invalid.
    """
    uv_coords = getattr(mesh.visual, 'uv', None)
    if uv_coords is None:
        # If the mesh does not have UV coordinates, generate them
        print("No UV coordinates found; generating default UV mapping.")
        uv_coords = generate_default_uv(mesh, quality=uv_quality)

    # Apply UV scaling and clamp into the valid texture range.
    uv_coords = np.clip(np.asarray(uv_coords, dtype=float) * uv_scale, 0.0, 1.0)

    # Force 3-channel RGB so grayscale, palette and RGBA images all sample
    # into the same (n_faces, 3) colour layout.
    texture_array = np.asarray(texture_image.convert("RGB"))
    img_height, img_width = texture_array.shape[:2]

    faces = np.asarray(mesh.faces, dtype=int).reshape(-1, 3)
    face_uv = uv_coords[faces]  # (n_faces, 3 vertices, 2)
    finite = np.isfinite(face_uv)
    valid = np.all(finite, axis=(1, 2))
    # Zero out non-finite UVs so the integer cast below is well defined;
    # those faces are overwritten with grey afterwards.
    face_uv = np.where(finite, face_uv, 0.0)

    cols = np.round(face_uv[..., 0] * (img_width - 1)).astype(int)
    # Trimesh UVs use a bottom-left origin while image row 0 is the top row,
    # so V must be flipped before it is used as a row index.
    rows = np.round((1.0 - face_uv[..., 1]) * (img_height - 1)).astype(int)

    face_colors = texture_array[rows, cols].mean(axis=1) / 255.0
    face_colors[~valid] = 0.5
    return face_colors
def load_glb_file(filename):
    """
    Load a model with ``trimesh.load`` and return it as a single mesh.

    When trimesh returns a ``Scene`` (several nodes/geometries), it is
    flattened with ``Scene.dump(concatenate=True)``; any other result is
    returned unchanged.
    """
    loaded = trimesh.load(filename)
    return loaded.dump(concatenate=True) if isinstance(loaded, trimesh.Scene) else loaded
def generate_clothing_image(prompt, num_inference_steps):
    """
    Generate a clothing texture from a text prompt with Stable Diffusion.

    Args:
        prompt: text description of the desired texture.
        num_inference_steps: number of denoising steps. Coerced to ``int``
            because Gradio sliders deliver their value as a float.

    Returns:
        The first image produced by the module-level ``pipeline``.
    """
    result = pipeline(prompt, num_inference_steps=int(num_inference_steps))
    return result.images[0]
def load_vrm_file(filename):
    """
    Load a VRM model and return it as a single trimesh mesh.

    The file is parsed with trimesh's glTF-binary loader (``file_type='glb'``).
    A ``Scene`` result is flattened with ``Scene.dump(concatenate=True)``.

    Raises:
        ValueError: if loading or flattening fails. The underlying exception
            is chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        vrm_data = trimesh.load(filename, file_type='glb')
        if isinstance(vrm_data, trimesh.Scene):
            return vrm_data.dump(concatenate=True)
        return vrm_data
    except Exception as e:
        raise ValueError(f"Failed to load VRM file: {e}") from e
def _load_mesh(obj_file, file_extension):
    """Load *obj_file* as a single trimesh mesh, choosing the loader by extension."""
    if file_extension == 'vrm':
        return load_vrm_file(obj_file)
    if file_extension == 'glb':
        return load_glb_file(obj_file)
    if file_extension == 'obj':
        loaded = trimesh.load(obj_file)
        # OBJ files with several objects/materials can load as a Scene, which
        # has no .vertices/.faces; flatten it into one mesh.
        if isinstance(loaded, trimesh.Scene):
            return loaded.dump(concatenate=True)
        return loaded
    raise ValueError("Unsupported file format. Please upload a .obj, .glb, or .vrm file.")


def _build_mesh_figure(mesh, face_colors, light_intensity, ambient_intensity, transparency):
    """Return a Plotly figure drawing *mesh* with per-face colours and fixed lighting."""
    vertices = mesh.vertices
    faces = mesh.faces
    fig = go.Figure(data=[
        go.Mesh3d(
            x=vertices[:, 0],
            y=vertices[:, 1],
            z=vertices[:, 2],
            i=faces[:, 0],
            j=faces[:, 1],
            k=faces[:, 2],
            facecolor=face_colors,
            opacity=transparency,
            lighting=dict(
                ambient=ambient_intensity,
                diffuse=light_intensity,
                specular=0.8,
                roughness=0.3,
                fresnel=0.1
            ),
            lightposition=dict(
                x=100,
                y=200,
                z=300
            )
        )
    ])
    fig.update_layout(scene=dict(aspectmode='data'))
    return fig


def display_3d_object(obj_file, texture_image, light_intensity, ambient_intensity, color, uv_scale, transparency, uv_quality=None):
    """
    Load a 3D model and return a Plotly figure rendering it.

    Args:
        obj_file: path to a .obj, .glb or .vrm file (format chosen by extension).
        texture_image: PIL image sampled per face via ``apply_texture``, or
            None to paint every face with *color*.
        light_intensity: diffuse lighting coefficient of the Mesh3d trace.
        ambient_intensity: ambient lighting coefficient of the Mesh3d trace.
        color: flat face colour (e.g. "#D3D3D3") used when there is no
            texture or the texture cannot be applied.
        uv_scale: UV multiplier forwarded to ``apply_texture``.
        transparency: Mesh3d opacity (1.0 = fully opaque).
        uv_quality: projection used when the mesh lacks UVs; None means 'medium'.

    Returns:
        plotly.graph_objs.Figure with ``aspectmode='data'``.

    Raises:
        ValueError: if *obj_file* is None, has an unsupported extension, or a
            VRM file cannot be loaded.
    """
    if obj_file is None:
        raise ValueError("Please upload a .obj, .glb, or .vrm file.")
    file_extension = obj_file.rsplit('.', 1)[-1].lower()
    mesh = _load_mesh(obj_file, file_extension)

    flat_colors = np.array([color] * len(mesh.faces))
    if texture_image is None:
        face_colors = flat_colors
    else:
        try:
            face_colors = apply_texture(mesh, texture_image, uv_scale, uv_quality or 'medium')
        except ValueError as err:
            # Rendering untextured is better than failing; report why.
            print(f"Could not apply texture ({err}); using flat color instead.")
            face_colors = flat_colors

    return _build_mesh_figure(mesh, face_colors, light_intensity, ambient_intensity, transparency)
def clear_texture():
    """Return None so the Gradio output bound to this callback is emptied."""
    cleared = None
    return cleared
def restore_original(obj_file):
    """Re-render *obj_file* untextured using the UI's default slider and colour values."""
    return display_3d_object(
        obj_file,
        texture_image=None,
        light_intensity=0.8,
        ambient_intensity=0.5,
        color="#D3D3D3",
        uv_scale=1.0,
        transparency=1.0,
    )
def update_texture_display(prompt, texture_file, num_inference_steps):
    """
    Produce the image shown in the texture preview.

    A non-blank *prompt* takes priority and is rendered with Stable Diffusion.
    Otherwise the uploaded *texture_file* (a file path) is opened. Returns
    None when neither is available.
    """
    # A whitespace-only prompt is treated as "no prompt" rather than spending
    # a full diffusion run on it.
    if prompt and prompt.strip():
        return generate_clothing_image(prompt, num_inference_steps)
    if texture_file:
        # Copy the pixels out so the file handle is closed now, instead of
        # leaving a lazily-loaded image holding it open.
        with Image.open(texture_file) as texture:
            return texture.copy()
    return None
def load_example(obj_file, texture_file):
    """
    Render one example-gallery entry using the UI's default settings.

    *obj_file* is a model path and *texture_file* an optional texture path.
    VRM models are additionally rendered with the 'medium' UV quality.
    """
    is_vrm = obj_file.split('.')[-1].lower() == 'vrm'
    texture_image = Image.open(texture_file) if texture_file else None
    # light, ambient, colour, UV scale, transparency
    defaults = (0.8, 0.5, "#D3D3D3", 1.0, 1.0)
    if is_vrm:
        return display_3d_object(obj_file, texture_image, *defaults, 'medium')
    return display_3d_object(obj_file, texture_image, *defaults)
# Gradio UI: controls in the left column, the Plotly viewer in the right.
# The texture preview is the single source of truth for what gets rendered:
# it holds either the generated image or the uploaded texture file.
with gr.Blocks() as demo:
    gr.Markdown("## 3D Object Viewer with Custom Texture, UV Scale, Transparency, Color, and Adjustable Lighting")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Texture Options")
            prompt_input = gr.Textbox(label="Enter a Prompt to Generate Texture", placeholder="Type a prompt...")
            num_inference_steps_slider = gr.Slider(minimum=5, maximum=100, step=1, value=10, label="Num Inference Steps")
            generate_button = gr.Button("Generate Texture")
            texture_file = gr.File(label="Upload Texture file (PNG or JPG, optional)", type="filepath")
            # type="pil" lets the previewed texture (generated or uploaded) be
            # passed straight to the renderer on Submit.
            texture_preview = gr.Image(label="Texture Preview", visible=True, type="pil")
            gr.Markdown("### Mapping, Lighting & Color Settings")
            uv_scale_slider = gr.Slider(minimum=0.1, maximum=5, step=0.1, value=1.0, label="UV Mapping Scale")
            uv_quality_dropdown = gr.Dropdown(label="UV Quality (for VRM files)", choices=['low', 'medium', 'high'], value='medium')
            light_intensity_slider = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Light Intensity")
            ambient_intensity_slider = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5, label="Ambient Intensity")
            transparency_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=1.0, label="Transparency")
            color_picker = gr.ColorPicker(value="#D3D3D3", label="Object Color")
            submit_button = gr.Button("Submit")
            restore_button = gr.Button("Restore")
            clear_button = gr.Button("Clear")
            obj_file = gr.File(label="Upload OBJ, GLB, or VRM file", value=DEFAULT_OBJ_FILE, type='filepath')
        with gr.Column(scale=2):
            display = gr.Plot(label="3D Viewer")

    def update_display(file, texture_image, uv_scale, uv_quality, light_intensity, ambient_intensity, transparency, color):
        """Render *file* using the texture currently in the preview (PIL image or None)."""
        if file is None:
            raise gr.Error("Please upload a .obj, .glb, or .vrm file.")
        file_extension = file.split('.')[-1].lower()
        if file_extension == 'vrm':
            return display_3d_object(file, texture_image, light_intensity, ambient_intensity, color, uv_scale, transparency, uv_quality)
        return display_3d_object(file, texture_image, light_intensity, ambient_intensity, color, uv_scale, transparency)

    def preview_uploaded_texture(path):
        """Show an uploaded texture in the preview; never triggers text-to-image generation."""
        if not path:
            return None
        with Image.open(path) as uploaded:
            return uploaded.copy()

    def toggle_uv_quality_dropdown(file):
        """Show the UV-quality dropdown only when a .vrm file is selected."""
        if file is None:
            return gr.update(visible=False)
        file_extension = file.split('.')[-1].lower()
        return gr.update(visible=(file_extension == 'vrm'))

    render_inputs = [obj_file, texture_preview, uv_scale_slider, uv_quality_dropdown, light_intensity_slider, ambient_intensity_slider, transparency_slider, color_picker]

    submit_button.click(fn=update_display, inputs=render_inputs, outputs=display)
    obj_file.change(fn=toggle_uv_quality_dropdown, inputs=[obj_file], outputs=uv_quality_dropdown)
    generate_button.click(fn=update_texture_display, inputs=[prompt_input, texture_file, num_inference_steps_slider], outputs=texture_preview)
    restore_button.click(fn=restore_original, inputs=[obj_file], outputs=display)
    clear_button.click(fn=clear_texture, outputs=texture_preview)
    texture_file.change(fn=preview_uploaded_texture, inputs=[texture_file], outputs=texture_preview)
    demo.load(fn=update_display, inputs=render_inputs, outputs=display)

    gr.Examples(
        examples=example_files,
        inputs=[obj_file, texture_file],
        outputs=display,  # Specify the output component
        fn=load_example,  # Specify the function to load the example
        label="Example Files",
        cache_examples=False,  # Disable caching
        # Without caching, Gradio only runs fn on click when this is set.
        run_on_click=True,
    )

if __name__ == "__main__":
    demo.launch(debug=True)