import os
from typing import Union
# This is an HF Spaces-specific hack for ZeroGPU
import spaces

import torch
import torch.nn as nn
from shap_e.models.transmitter.base import Transmitter, VectorDecoder
import gradio as gr
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from einops import rearrange
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras
from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera
import math
import time
from requests.exceptions import ReadTimeout, ConnectionError
from shap_e.util.collections import AttrDict

from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
    FOV_to_intrinsics, 
    get_zero123plus_input_cameras,
    get_circular_camera_poses,
    spherical_camera_pose
)
from src.utils.mesh_util import save_obj, save_glb
from src.utils.infer_util import remove_background, resize_foreground

def decode_latent_images(
    xm: Union[Transmitter, VectorDecoder],
    latent: torch.Tensor,
    cameras: DifferentiableCameraBatch,
    rendering_mode: str = "stf",
    params = None,
    background_color: torch.Tensor = torch.tensor([255.0, 255.0, 255.0], dtype=torch.float32),
):
    params = params if params is not None else (xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params(
            latent[None]
        )
    params = xm.renderer.update(params)
    decoded = xm.renderer.render_views(
        AttrDict(cameras=cameras),
        params=params,
        options=AttrDict(rendering_mode=rendering_mode, render_with_direction=False),
    )
    bg_color = background_color.to(decoded.channels.device)
    images = bg_color * decoded.transmittance + (1 - decoded.transmittance) * decoded.channels

    return images
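
# Usage sketch (inputs are illustrative): decode one latent with the stock
# pan cameras that ship with Shap-E.
#
#   cameras = create_pan_cameras(64, torch.device('cuda'))
#   views = decode_latent_images(xm, latents[0], cameras, rendering_mode='stf')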

def create_custom_cameras(size: int, device: torch.device, azimuths: list, elevations: list, 
                          fov_degrees: float, distance: float) -> DifferentiableCameraBatch:
    # Object is in a 2x2x2 bounding box (-1 to 1 in each dimension);
    # `distance` is the extent we want framed in view
    object_diagonal = distance

    # Camera radius such that the framed extent fills the field of view
    fov_radians = math.radians(fov_degrees)
    radius = (object_diagonal / 2) / math.tan(fov_radians / 2)
    
    origins = []
    xs = []
    ys = []
    zs = []
    
    for azimuth, elevation in zip(azimuths, elevations):
        azimuth_rad = np.radians(azimuth-90)
        elevation_rad = np.radians(elevation)
        
        # Calculate camera position
        x = radius * np.cos(elevation_rad) * np.cos(azimuth_rad)
        y = radius * np.cos(elevation_rad) * np.sin(azimuth_rad)
        z = radius * np.sin(elevation_rad)
        origin = np.array([x, y, z])
        
        # Calculate camera orientation
        z_axis = -origin / np.linalg.norm(origin)  # Point towards center
        x_axis = np.array([-np.sin(azimuth_rad), np.cos(azimuth_rad), 0])
        y_axis = np.cross(z_axis, x_axis)
        
        origins.append(origin)
        zs.append(z_axis)
        xs.append(x_axis)
        ys.append(y_axis)

    return DifferentiableCameraBatch(
        shape=(1, len(origins)),
        flat_camera=DifferentiableProjectiveCamera(
            origin=torch.from_numpy(np.stack(origins, axis=0)).float().to(device),
            x=torch.from_numpy(np.stack(xs, axis=0)).float().to(device),
            y=torch.from_numpy(np.stack(ys, axis=0)).float().to(device),
            z=torch.from_numpy(np.stack(zs, axis=0)).float().to(device),
            width=size,
            height=size,
            x_fov=fov_radians,
            y_fov=fov_radians,
        ),
    )
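
# Example (hypothetical angles): four evenly spaced side views at eye level.
#
#   cams = create_custom_cameras(
#       size=320, device=torch.device('cuda'),
#       azimuths=[0, 90, 180, 270], elevations=[0, 0, 0, 0],
#       fov_degrees=30, distance=3.0,
#   )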

def load_models():
    """Initialize and load all required models"""
    config = OmegaConf.load('configs/instant-nerf-large-best.yaml')
    model_config = config.model_config
    infer_config = config.infer_config

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Load diffusion pipeline with retry logic
    print('Loading diffusion pipeline...')
    max_retries = 3
    retry_delay = 5
    
    for attempt in range(max_retries):
        try:
            pipeline = DiffusionPipeline.from_pretrained(
                "sudo-ai/zero123plus-v1.2",
                custom_pipeline="zero123plus",
                torch_dtype=torch.float16,
                local_files_only=False,
                resume_download=True,
            )
            break
        except (ReadTimeout, ConnectionError) as e:
            if attempt == max_retries - 1:
                raise Exception(f"Failed to download pipeline after {max_retries} attempts: {str(e)}")
            print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2  # Exponential backoff
    
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config, timestep_spacing='trailing'
    )

    # Widen the UNet input from 4 to 8 latent channels; the extra channels
    # are zero-initialized below so the pretrained weights still apply
    in_channels = 8
    out_channels = pipeline.unet.conv_in.out_channels
    pipeline.unet.register_to_config(in_channels=in_channels)
    with torch.no_grad():
        new_conv_in = nn.Conv2d(
            in_channels, out_channels, pipeline.unet.conv_in.kernel_size, 
            pipeline.unet.conv_in.stride, pipeline.unet.conv_in.padding
        )
        new_conv_in.weight.zero_()
        new_conv_in.weight[:, :4, :, :].copy_(pipeline.unet.conv_in.weight)
        pipeline.unet.conv_in = new_conv_in

    # Load custom UNet with retry logic
    print('Loading custom UNet...')
    for attempt in range(max_retries):
        try:
            pipeline.unet = pipeline.unet.from_pretrained(
                "YiftachEde/Sharp-It",
                local_files_only=False,
                resume_download=True,
            ).to(torch.float16)
            break
        except (ReadTimeout, ConnectionError) as e:
            if attempt == max_retries - 1:
                raise Exception(f"Failed to download UNet after {max_retries} attempts: {str(e)}")
            print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2

    pipeline = pipeline.to(device, torch.float16)

    # Load reconstruction model with retry logic
    print('Loading reconstruction model...')
    model = instantiate_from_config(model_config)
    
    for attempt in range(max_retries):
        try:
            model_path = hf_hub_download(
                repo_id="TencentARC/InstantMesh",
                filename="instant_nerf_large.ckpt",
                repo_type="model",
                local_files_only=False,
                resume_download=True,
                cache_dir="model_cache"  # Use a specific cache directory
            )
            break
        except (ReadTimeout, ConnectionError) as e:
            if attempt == max_retries - 1:
                raise Exception(f"Failed to download model after {max_retries} attempts: {str(e)}")
            print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2

    state_dict = torch.load(model_path, map_location='cpu')['state_dict']
    # Keep only the LRM generator weights, stripping the 'lrm_generator.' prefix
    state_dict = {k[len('lrm_generator.'):]: v for k, v in state_dict.items()
                  if k.startswith('lrm_generator.') and 'source_camera' not in k}
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device)
    model.eval()
    
    return pipeline, model, infer_config

@spaces.GPU(duration=20)
def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
    """Process input images and run refinement"""
    device = pipeline.device
    
    if isinstance(input_images, list):
        if len(input_images) == 1:
            # Check if this is a pre-arranged layout
            img = Image.open(input_images[0].name).convert('RGB')
            if img.size == (640, 960):  # PIL size is (width, height)
                # Already a full 640x960 layout, use it directly
                input_image = img
            else:
                # Single view - need 6 copies
                img = img.resize((320, 320))
                img_array = np.array(img) / 255.0
                images = [img_array] * 6
                images = np.stack(images)
                
                # Convert to tensor and create layout
                images = torch.from_numpy(images).float()
                images = images.permute(0, 3, 1, 2)
                images = images.reshape(3, 2, 3, 320, 320)
                images = images.permute(0, 2, 3, 1, 4)
                images = images.reshape(3, 3, 320, 640)
                images = images.reshape(1, 3, 960, 640)
                
                # Convert back to PIL
                images = images.permute(0, 2, 3, 1)[0]
                images = (images.numpy() * 255).astype(np.uint8)
                input_image = Image.fromarray(images)
        else:
            # Multiple individual views
            images = []
            for img_file in input_images:
                img = Image.open(img_file.name).convert('RGB')
                img = img.resize((320, 320))
                img = np.array(img) / 255.0
                images.append(img)
            
            # Pad to 6 images if needed
            while len(images) < 6:
                images.append(np.zeros_like(images[0]))
            images = np.stack(images[:6])
            
            # Convert to tensor and create layout
            images = torch.from_numpy(images).float()
            images = images.permute(0, 3, 1, 2)
            images = images.reshape(3, 2, 3, 320, 320)
            images = images.permute(0, 2, 3, 1, 4)
            images = images.reshape(3, 3, 320, 640)
            images = images.reshape(1, 3, 960, 640)
            
            # Convert back to PIL
            images = images.permute(0, 2, 3, 1)[0]
            images = (images.numpy() * 255).astype(np.uint8)
            input_image = Image.fromarray(images)
    else:
        raise ValueError("Expected a list of images")

    # Generate refined output
    output = pipeline.refine(
        input_image,
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=guidance_scale
    ).images[0]
    
    return output, input_image
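
# Sketch of a direct call (paths and names are placeholders); each entry
# needs a .name attribute, as Gradio file uploads provide:
#
#   from types import SimpleNamespace
#   files = [SimpleNamespace(name='front_view.png')]
#   refined, layout = process_images(files, 'a toy robot', pipeline=pipeline)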

@spaces.GPU(duration=20)
def create_mesh(refined_image, model, infer_config):
    """Generate mesh from refined image"""
    # Convert PIL image to tensor
    image = np.array(refined_image) / 255.0
    image = torch.from_numpy(image).float().permute(2, 0, 1)
    
    # Split the 960x640 layout (3 rows x 2 cols of 320px tiles) into 6 views
    image = image.reshape(3, 960, 640)
    image = image.reshape(3, 3, 320, 640)
    image = image.permute(1, 0, 2, 3)
    image = image.reshape(3, 3, 320, 2, 320)
    image = image.permute(0, 3, 1, 2, 4)
    image = image.reshape(6, 3, 320, 320)
    
    # Add batch dimension
    image = image.unsqueeze(0)
    
    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to("cuda")
    image = image.to("cuda")
    
    with torch.no_grad():
        planes = model.forward_planes(image, input_cameras)
        mesh_out = model.extract_mesh(planes, **infer_config)
        vertices, faces, vertex_colors = mesh_out
        
    return vertices, faces, vertex_colors
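
# Follow-up sketch: persist a generated mesh with the helper imported above
# (inputs are illustrative):
#
#   verts, faces, colors = create_mesh(refined_image, model, infer_config)
#   save_obj(verts, faces, colors, 'refined_mesh.obj')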

class ShapERenderer:
    def __init__(self, device):
        print("Initializing Shap-E models...")
        self.device = device
        torch.cuda.empty_cache()  # Clear GPU memory before loading
        self.xm = load_model('transmitter', device=self.device)
        self.model = load_model('text300M', device=self.device)
        self.diffusion = diffusion_from_config(load_config('diffusion'))
        print("Shap-E models initialized!")
    
    @spaces.GPU(duration=80)  
    def generate_views(self, prompt, guidance_scale=15.0, num_steps=64):
        try:
            torch.cuda.empty_cache()  # Clear GPU memory before generation
            
            # Generate latents using the text-to-3D model
            batch_size = 1
            guidance_scale = float(guidance_scale)
            
            with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                # Generate latents directly without nested spaces.GPU context
                latents = sample_latents(
                    batch_size=batch_size,
                    model=self.model,
                    diffusion=self.diffusion,
                    guidance_scale=guidance_scale,
                    model_kwargs=dict(texts=[prompt] * batch_size),
                    progress=True,
                    clip_denoised=True,
                    use_fp16=True,
                    use_karras=True,
                    karras_steps=num_steps,
                    sigma_min=1e-3,
                    sigma_max=160,
                    s_churn=0,
                )

            # Render the 6 views we need with specific viewing angles
            size = 320  # Size of each rendered image
            images = []
            
            # Define our 6 specific camera positions to match refine.py
            azimuths = [30, 90, 150, 210, 270, 330]
            elevations = [20, -10, 20, -10, 20, -10]
            
            for i, (azimuth, elevation) in enumerate(zip(azimuths, elevations)):
                cameras = create_custom_cameras(size, self.device, azimuths=[azimuth], elevations=[elevation], fov_degrees=30, distance=3.0)
                with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                    rendered_image = decode_latent_images(
                        self.xm,
                        latents[0],
                        cameras=cameras,
                        rendering_mode='stf'
                    )
                images.append(rendered_image[0])
                torch.cuda.empty_cache()  # Clear GPU memory after each view
            
            # Move to CPU, clamp to the display range, and convert to uint8
            images = [img.clamp(0, 255).detach().cpu().numpy().astype(np.uint8) for img in images]
            
            # Create a 3x2 grid layout (960 tall x 640 wide)
            layout = np.zeros((960, 640, 3), dtype=np.uint8)
            for i, img in enumerate(images):
                row = i // 2
                col = i % 2
                layout[row*320:(row+1)*320, col*320:(col+1)*320] = img

            return Image.fromarray(layout), images
            
        except Exception as e:
            print(f"Error in generate_views: {e}")
            torch.cuda.empty_cache()  # Clear GPU memory on error
            raise
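
# Usage sketch for the renderer on its own, outside the Gradio app
# (prompt and path are illustrative):
#
#   renderer = ShapERenderer(torch.device('cuda'))
#   layout, views = renderer.generate_views('a ceramic mug', guidance_scale=15.0)
#   layout.save('shap_e_views.png')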

class RefinerInterface:
    def __init__(self):
        print("Initializing InstantMesh models...")
        torch.cuda.empty_cache()  # Clear GPU memory before loading
        self.pipeline, self.model, self.infer_config = load_models()
        print("InstantMesh models initialized!")
    
    def refine_model(self, input_image, prompt, steps=75, guidance_scale=7.5):
        """Main refinement function"""
        try:
            torch.cuda.empty_cache()  # Clear GPU memory before processing
            
            # Process image and get refined output
            input_image = Image.fromarray(input_image)
            
            # Rearrange if we received a landscape layout (960 wide x 640 tall);
            # the pipeline expects a portrait layout (640 wide x 960 tall)
            if input_image.width == 960 and input_image.height == 640:
                input_array = np.array(input_image)
                new_layout = np.zeros((960, 640, 3), dtype=np.uint8)
                
                # Rearrange from 2x3 to 3x2
                for i in range(6):
                    src_row = i // 3
                    src_col = i % 3
                    dst_row = i // 2
                    dst_col = i % 2
                    
                    new_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
                        input_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
                
                input_image = Image.fromarray(new_layout)
            
            # Run the refinement pipeline on the 640-wide x 960-tall layout
            with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                refined_output_960x640 = self.pipeline.refine(
                    input_image,
                    prompt=prompt,
                    num_inference_steps=int(steps),
                    guidance_scale=guidance_scale
                ).images[0]
            
            torch.cuda.empty_cache()  # Clear GPU memory after refinement
            
            # Generate the mesh from the refined layout
            with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                vertices, faces, vertex_colors = create_mesh(
                    refined_output_960x640, 
                    self.model, 
                    self.infer_config
                )
            
            torch.cuda.empty_cache()  # Clear GPU memory after mesh generation
            
            # Save temporary mesh file
            os.makedirs("temp", exist_ok=True)
            temp_obj = os.path.join("temp", "refined_mesh.obj")
            save_obj(vertices, faces, vertex_colors, temp_obj)
            
            # The refined output already matches the 3x2 portrait display
            # layout, so it can be returned directly
            refined_output_640x960 = refined_output_960x640
            
            return refined_output_640x960, temp_obj
            
        except Exception as e:
            print(f"Error in refine_model: {e}")
            torch.cuda.empty_cache()  # Clear GPU memory on error
            raise

def create_demo():
    print("Initializing models...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Initialize models at startup
    shap_e = ShapERenderer(device)
    refiner = RefinerInterface()
    print("All models initialized!")
    
    with gr.Blocks() as demo:
        gr.Markdown("# Shap-E to InstantMesh Pipeline")
        
        # First row: Controls
        with gr.Row():
            with gr.Column():
                # Shap-E inputs
                shape_prompt = gr.Textbox(
                    label="Shap-E Prompt", 
                    placeholder="Enter text to generate initial 3D model..."
                )
                shape_guidance = gr.Slider(
                    minimum=1, 
                    maximum=30, 
                    value=15.0, 
                    label="Shap-E Guidance Scale"
                )
                shape_steps = gr.Slider(
                    minimum=16, 
                    maximum=128, 
                    value=64, 
                    step=16, 
                    label="Shap-E Steps"
                )
                generate_btn = gr.Button("Generate Views")
            
            with gr.Column():
                # Refinement inputs
                refine_prompt = gr.Textbox(
                    label="Refinement Prompt", 
                    placeholder="Enter prompt to guide refinement..."
                )
                refine_steps = gr.Slider(
                    minimum=30,
                    maximum=100,
                    value=75,
                    step=1,
                    label="Refinement Steps"
                )
                refine_guidance = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=7.5,
                    label="Refinement Guidance Scale"
                )
                refine_btn = gr.Button("Refine")
                error_output = gr.Textbox(label="Status/Error Messages", interactive=False)

        # Second row: Image panels side by side
        with gr.Row():
            # Outputs - Images side by side
            shape_output = gr.Image(
                label="Generated Views", 
                width=640,
                height=960
            )
            refined_output = gr.Image(
                label="Refined Output",
                width=640,
                height=960
            )
        
        # Third row: 3D mesh panel below
        with gr.Row():
            # 3D mesh centered
            mesh_output = gr.Model3D(
                label="3D Mesh", 
                clear_color=[1.0, 1.0, 1.0, 1.0],
            )

        # Set up event handlers
        @spaces.GPU(duration=100)  # Add GPU decorator to the generate function
        def generate(prompt, guidance_scale, num_steps):
            try:
                torch.cuda.empty_cache()  # Clear GPU memory before starting
                with torch.no_grad():
                    layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
                return layout, None  # Return None for error message
            except Exception as e:
                torch.cuda.empty_cache()  # Clear GPU memory on error
                error_msg = f"Error during generation: {str(e)}"
                print(error_msg)
                return None, error_msg

        @spaces.GPU(duration=20)
        def refine(input_image, prompt, steps, guidance_scale):
            try:
                torch.cuda.empty_cache()  # Clear GPU memory before starting
                refined_img, mesh_path = refiner.refine_model(
                    input_image, 
                    prompt,
                    steps,
                    guidance_scale
                )
                return refined_img, mesh_path, None  # Return None for error message
            except Exception as e:
                torch.cuda.empty_cache()  # Clear GPU memory on error
                error_msg = f"Error during refinement: {str(e)}"
                print(error_msg)
                return None, None, error_msg

        generate_btn.click(
            fn=generate,
            inputs=[shape_prompt, shape_guidance, shape_steps],
            outputs=[shape_output, error_output]
        )

        refine_btn.click(
            fn=refine,
            inputs=[shape_output, refine_prompt, refine_steps, refine_guidance],
            outputs=[refined_output, mesh_output, error_output]
        )

    return demo

if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)