ginipick commited on
Commit
6fd4f6a
·
verified ·
1 Parent(s): e66bfa0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -258
app.py CHANGED
@@ -1,259 +1,2 @@
1
- import gradio as gr
2
- import spaces
3
- from gradio_litmodel3d import LitModel3D
4
  import os
5
- os.environ['SPCONV_ALGO'] = 'native'
6
- from typing import *
7
- import torch
8
- import numpy as np
9
- import imageio
10
- import uuid
11
- from easydict import EasyDict as edict
12
- from PIL import Image
13
- from trellis.pipelines import TrellisImageTo3DPipeline
14
- from trellis.representations import Gaussian, MeshExtractResult
15
- from trellis.utils import render_utils, postprocessing_utils
16
-
17
- # 기본 설정
18
- MAX_SEED = np.iinfo(np.int32).max
19
- TMP_DIR = "/tmp/Trellis-demo"
20
- os.makedirs(TMP_DIR, exist_ok=True)
21
-
22
- # CUDA 초기화 함수
23
- def init_cuda():
24
- try:
25
- if torch.cuda.is_available():
26
- device = torch.device('cuda')
27
- print("CUDA 초기화 성공")
28
- else:
29
- device = torch.device('cpu')
30
- print("CUDA를 사용할 수 없어 CPU를 사용합니다")
31
- return device
32
- except Exception as e:
33
- print(f"CUDA 초기화 중 오류 발생: {e}")
34
- return torch.device('cpu')
35
-
36
- def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
37
- """
38
- 입력 이미지 전처리
39
- """
40
- trial_id = str(uuid.uuid4())
41
- processed_image = pipeline.preprocess_image(image)
42
- processed_image.save(f"{TMP_DIR}/{trial_id}.png")
43
- return trial_id, processed_image
44
-
45
- def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
46
- """
47
- 상태 정보 패킹
48
- """
49
- return {
50
- 'gaussian': {
51
- **gs.init_params,
52
- '_xyz': gs._xyz.cpu().numpy(),
53
- '_features_dc': gs._features_dc.cpu().numpy(),
54
- '_scaling': gs._scaling.cpu().numpy(),
55
- '_rotation': gs._rotation.cpu().numpy(),
56
- '_opacity': gs._opacity.cpu().numpy(),
57
- },
58
- 'mesh': {
59
- 'vertices': mesh.vertices.cpu().numpy(),
60
- 'faces': mesh.faces.cpu().numpy(),
61
- },
62
- 'trial_id': trial_id,
63
- }
64
-
65
- def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
66
- """
67
- 상태 정보 언패킹
68
- """
69
- device = init_cuda()
70
- gs = Gaussian(
71
- aabb=state['gaussian']['aabb'],
72
- sh_degree=state['gaussian']['sh_degree'],
73
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
74
- scaling_bias=state['gaussian']['scaling_bias'],
75
- opacity_bias=state['gaussian']['opacity_bias'],
76
- scaling_activation=state['gaussian']['scaling_activation'],
77
- )
78
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device=device)
79
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device=device)
80
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device=device)
81
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device=device)
82
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device=device)
83
-
84
- mesh = edict(
85
- vertices=torch.tensor(state['mesh']['vertices'], device=device),
86
- faces=torch.tensor(state['mesh']['faces'], device=device),
87
- )
88
-
89
- return gs, mesh, state['trial_id']
90
-
91
- @spaces.GPU
92
- def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float, ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int) -> Tuple[dict, str]:
93
- """
94
- 이미지를 3D 모델로 변환
95
- """
96
- try:
97
- if randomize_seed:
98
- seed = np.random.randint(0, MAX_SEED)
99
- outputs = pipeline.run(
100
- Image.open(f"{TMP_DIR}/{trial_id}.png"),
101
- seed=seed,
102
- formats=["gaussian", "mesh"],
103
- preprocess_image=False,
104
- sparse_structure_sampler_params={
105
- "steps": ss_sampling_steps,
106
- "cfg_strength": ss_guidance_strength,
107
- },
108
- slat_sampler_params={
109
- "steps": slat_sampling_steps,
110
- "cfg_strength": slat_guidance_strength,
111
- },
112
- )
113
- video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
114
- video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
115
- video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
116
- trial_id = uuid.uuid4()
117
- video_path = f"{TMP_DIR}/{trial_id}.mp4"
118
- os.makedirs(os.path.dirname(video_path), exist_ok=True)
119
- imageio.mimsave(video_path, video, fps=15)
120
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
121
- return state, video_path
122
- except Exception as e:
123
- print(f"3D 변환 중 오류 발생: {e}")
124
- return None, None
125
-
126
- @spaces.GPU
127
- def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[str, str]:
128
- """
129
- 3D 모델에서 GLB 파일 추출
130
- """
131
- try:
132
- gs, mesh, trial_id = unpack_state(state)
133
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
134
- glb_path = f"{TMP_DIR}/{trial_id}.glb"
135
- glb.export(glb_path)
136
- return glb_path, glb_path
137
- except Exception as e:
138
- print(f"GLB 추출 중 오류 발생: {e}")
139
- return None, None
140
-
141
- def activate_button() -> gr.Button:
142
- return gr.Button(interactive=True)
143
-
144
- def deactivate_button() -> gr.Button:
145
- return gr.Button(interactive=False)
146
-
147
- # Gradio 인터페이스 설정
148
- css = """
149
- footer {
150
- visibility: hidden;
151
- }
152
- """
153
-
154
- with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
155
- gr.Markdown("""
156
- ## Anything 3D""")
157
-
158
- with gr.Row():
159
- with gr.Column():
160
- image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)
161
-
162
- with gr.Accordion(label="Generation Settings", open=False):
163
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
164
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
165
- gr.Markdown("Stage 1: Sparse Structure Generation")
166
- with gr.Row():
167
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
168
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
169
- gr.Markdown("Stage 2: Structured Latent Generation")
170
- with gr.Row():
171
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
172
- slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
173
-
174
- generate_btn = gr.Button("Generate")
175
-
176
- with gr.Accordion(label="GLB Extraction Settings", open=False):
177
- mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
178
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
179
-
180
- extract_glb_btn = gr.Button("Extract GLB", interactive=False)
181
-
182
- with gr.Column():
183
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
184
- model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
185
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
186
-
187
- trial_id = gr.Textbox(visible=False)
188
- output_buf = gr.State()
189
-
190
- # 예제 이미지 설정
191
- with gr.Row():
192
- examples = gr.Examples(
193
- examples=[
194
- f'assets/example_image/{image}'
195
- for image in os.listdir("assets/example_image")
196
- ],
197
- inputs=[image_prompt],
198
- fn=preprocess_image,
199
- outputs=[trial_id, image_prompt],
200
- run_on_click=True,
201
- examples_per_page=64,
202
- )
203
-
204
- # 이벤트 핸들러 설정
205
- image_prompt.upload(
206
- preprocess_image,
207
- inputs=[image_prompt],
208
- outputs=[trial_id, image_prompt],
209
- )
210
- image_prompt.clear(
211
- lambda: '',
212
- outputs=[trial_id],
213
- )
214
-
215
- generate_btn.click(
216
- image_to_3d,
217
- inputs=[trial_id, seed, randomize_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
218
- outputs=[output_buf, video_output],
219
- ).then(
220
- activate_button,
221
- outputs=[extract_glb_btn],
222
- )
223
-
224
- video_output.clear(
225
- deactivate_button,
226
- outputs=[extract_glb_btn],
227
- )
228
-
229
- extract_glb_btn.click(
230
- extract_glb,
231
- inputs=[output_buf, mesh_simplify, texture_size],
232
- outputs=[model_output, download_glb],
233
- ).then(
234
- activate_button,
235
- outputs=[download_glb],
236
- )
237
-
238
- model_output.clear(
239
- deactivate_button,
240
- outputs=[download_glb],
241
- )
242
-
243
- # 메인 실행부
244
- if __name__ == "__main__":
245
- try:
246
- device = init_cuda()
247
- pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
248
- pipeline.to(device)
249
-
250
- # rembg 사전 로드 시도
251
- try:
252
- pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
253
- except Exception as e:
254
- print(f"사전 로드 중 오류 발생: {e}")
255
-
256
- # 공유 GPU 환경을 위한 설정으로 데모 실행
257
- demo.queue(max_size=10).launch(share=True)
258
- except Exception as e:
259
- print(f"애플리케이션 시작 중 오류 발생: {e}")
 
 
 
 
1
  import os
2
+ exec(os.environ.get('APP'))