Nikhil0987 commited on
Commit
ebf16ac
1 Parent(s): aedd5aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -25
app.py CHANGED
@@ -1,35 +1,186 @@
1
- import pygame
2
- from pygame.locals import *
 
 
 
 
 
 
3
 
4
- # Initialize Pygame
5
- pygame.init()
 
6
 
7
- # Load the Talking Tom face image
8
- face_model = pygame.image.load('pictures/talking_tom_face.png')
 
 
9
 
10
- # Create a display surface
11
- screen = pygame.display.set_mode((640, 480), 0, 32)
12
 
13
- # Main game loop
14
- running = True
15
- while running:
16
- # Handle events
17
- for event in pygame.event.get():
18
- if event.type == QUIT:
19
- running = False
20
 
21
- # --- Drawing Logic ---
 
22
 
23
- # Clear the screen (fill with a background color)
24
- screen.fill((255, 255, 255)) # Example: Fill with white
25
 
26
- # Draw the Talking Tom's face
27
- screen.blit(face_model, (100, 100))
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # --- End Drawing Logic ---
 
30
 
31
- # Update the display
32
- pygame.display.flip()
 
 
 
 
 
 
 
 
33
 
34
- # Quit Pygame
35
- pygame.quit()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torchvision.transforms.functional as TF
6
+ from safetensors.torch import load_file
7
+ import rembg
8
+ import gradio as gr
9
 
10
+ # download checkpoints
11
+ from huggingface_hub import hf_hub_download
12
+ ckpt_path = hf_hub_download(repo_id="ashawkey/LGM", filename="model_fp16.safetensors")
13
 
14
+ try:
15
+ import diff_gaussian_rasterization
16
+ except ImportError:
17
+ os.system("pip install ./diff-gaussian-rasterization")
18
 
19
+ import kiui
20
+ from kiui.op import recenter
21
 
22
+ from core.options import Options
23
+ from core.models import LGM
24
+ from mvdream.pipeline_mvdream import MVDreamPipeline
 
 
 
 
25
 
26
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
27
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
28
 
29
+ TMP_DIR = '/tmp'
30
+ os.makedirs(TMP_DIR, exist_ok=True)
31
 
32
+ # opt = tyro.cli(AllConfigs)
33
+ opt = Options(
34
+ input_size=256,
35
+ up_channels=(1024, 1024, 512, 256, 128), # one more decoder
36
+ up_attention=(True, True, True, False, False),
37
+ splat_size=128,
38
+ output_size=512, # render & supervise Gaussians at a higher resolution.
39
+ batch_size=8,
40
+ num_views=8,
41
+ gradient_accumulation_steps=1,
42
+ mixed_precision='bf16',
43
+ resume=ckpt_path,
44
+ )
45
 
46
+ # model
47
+ model = LGM(opt)
48
 
49
+ # resume pretrained checkpoint
50
+ if opt.resume is not None:
51
+ if opt.resume.endswith('safetensors'):
52
+ ckpt = load_file(opt.resume, device='cpu')
53
+ else:
54
+ ckpt = torch.load(opt.resume, map_location='cpu')
55
+ model.load_state_dict(ckpt, strict=False)
56
+ print(f'[INFO] Loaded checkpoint from {opt.resume}')
57
+ else:
58
+ print(f'[WARN] model randomly initialized, are you sure?')
59
 
60
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
61
+ model = model.half().to(device)
62
+ model.eval()
63
+
64
+ tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
65
+ proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
66
+ proj_matrix[0, 0] = -1 / tan_half_fov
67
+ proj_matrix[1, 1] = -1 / tan_half_fov
68
+ proj_matrix[2, 2] = - (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
69
+ proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
70
+ proj_matrix[2, 3] = 1
71
+
72
+ # load dreams
73
+ pipe_text = MVDreamPipeline.from_pretrained(
74
+ 'ashawkey/mvdream-sd2.1-diffusers', # remote weights
75
+ torch_dtype=torch.float16,
76
+ trust_remote_code=True,
77
+ # local_files_only=True,
78
+ )
79
+ pipe_text = pipe_text.to(device)
80
+
81
+ pipe_image = MVDreamPipeline.from_pretrained(
82
+ "ashawkey/imagedream-ipmv-diffusers", # remote weights
83
+ torch_dtype=torch.float16,
84
+ trust_remote_code=True,
85
+ # local_files_only=True,
86
+ )
87
+ pipe_image = pipe_image.to(device)
88
+
89
+ # load rembg
90
+ bg_remover = rembg.new_session()
91
+
92
+ # process function
93
+ def run(input_image):
94
+ prompt_neg = "ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate"
95
+
96
+ # seed
97
+ kiui.seed_everything(42)
98
+
99
+ output_ply_path = os.path.join(TMP_DIR, 'output.ply')
100
+
101
+ input_image = np.array(input_image) # uint8
102
+ # bg removal
103
+ carved_image = rembg.remove(input_image, session=bg_remover) # [H, W, 4]
104
+ mask = carved_image[..., -1] > 0
105
+ image = recenter(carved_image, mask, border_ratio=0.2)
106
+ image = image.astype(np.float32) / 255.0
107
+ image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
108
+ mv_image = pipe_image("", image, negative_prompt=prompt_neg, num_inference_steps=30, guidance_scale=5.0, elevation=0)
109
+
110
+ # generate gaussians
111
+ input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32
112
+ input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
113
+ input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
114
+ input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
115
+
116
+ rays_embeddings = model.prepare_default_rays(device, elevation=0)
117
+ input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]
118
+
119
+ with torch.no_grad():
120
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
121
+ # generate gaussians
122
+ gaussians = model.forward_gaussians(input_image)
123
+
124
+ # save gaussians
125
+ model.gs.save_ply(gaussians, output_ply_path)
126
+
127
+ return output_ply_path
128
+
129
+ # gradio UI
130
+
131
+ _TITLE = '''LGM Mini'''
132
+
133
+ _DESCRIPTION = '''
134
+ <div>
135
+ A lightweight version of <a href="https://huggingface.co/spaces/ashawkey/LGM">LGM: Large Multi-View Gaussian Model for High-Resolution 3D Content Creation</a>.
136
+ </div>
137
+ '''
138
+
139
+ css = '''
140
+ #duplicate-button {
141
+ margin: auto;
142
+ color: white;
143
+ background: #1565c0;
144
+ border-radius: 100vh;
145
+ }
146
+ '''
147
+
148
+ block = gr.Blocks(title=_TITLE, css=css)
149
+ with block:
150
+ gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
151
+
152
+ with gr.Row():
153
+ with gr.Column(scale=1):
154
+ gr.Markdown('# ' + _TITLE)
155
+ gr.Markdown(_DESCRIPTION)
156
+
157
+ with gr.Row(variant='panel'):
158
+ with gr.Column(scale=1):
159
+ # input image
160
+ input_image = gr.Image(label="image", type='pil', height=320)
161
+ # gen button
162
+ button_gen = gr.Button("Generate")
163
+
164
+
165
+ with gr.Column(scale=1):
166
+ output_splat = gr.Model3D(label="3D Gaussians")
167
+
168
+ button_gen.click(fn=run, inputs=[input_image], outputs=[output_splat])
169
+
170
+ gr.Examples(
171
+ examples=[
172
+ "data_test/frog_sweater.jpg",
173
+ "data_test/bird.jpg",
174
+ "data_test/boy.jpg",
175
+ "data_test/cat_statue.jpg",
176
+ "data_test/dragontoy.jpg",
177
+ "data_test/gso_rabbit.jpg",
178
+ ],
179
+ inputs=[input_image],
180
+ outputs=[output_splat],
181
+ fn=lambda x: run(input_image=x),
182
+ cache_examples=True,
183
+ label='Image-to-3D Examples'
184
+ )
185
+
186
+ block.queue().launch(debug=True, share=True)