gokaygokay John6666 committed on
Commit
e67f19a
·
verified ·
1 Parent(s): 6afb035

Update app.py (#4)

Browse files

- Update app.py (866d14323d35a3ee065a9475637f93b0b43ef07a)


Co-authored-by: John Smith <John6666@users.noreply.huggingface.co>

Files changed (1) hide show
  1. app.py +299 -299
app.py CHANGED
@@ -1,300 +1,300 @@
1
- import gradio as gr
2
- import spaces
3
- from gradio_litmodel3d import LitModel3D
4
- import os
5
- import shutil
6
- import random
7
- import uuid
8
- from datetime import datetime
9
- from diffusers import DiffusionPipeline
10
-
11
- os.environ['SPCONV_ALGO'] = 'native'
12
- from typing import *
13
- import torch
14
- import numpy as np
15
- import imageio
16
- from easydict import EasyDict as edict
17
- from PIL import Image
18
- from trellis.pipelines import TrellisImageTo3DPipeline
19
- from trellis.representations import Gaussian, MeshExtractResult
20
- from trellis.utils import render_utils, postprocessing_utils
21
-
22
- huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
23
- # Constants
24
- MAX_SEED = np.iinfo(np.int32).max
25
- TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
26
- os.makedirs(TMP_DIR, exist_ok=True)
27
-
28
- # Create permanent storage directory for Flux generated images
29
- SAVE_DIR = "saved_images"
30
- if not os.path.exists(SAVE_DIR):
31
- os.makedirs(SAVE_DIR, exist_ok=True)
32
-
33
- def start_session(req: gr.Request):
34
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
35
- os.makedirs(user_dir, exist_ok=True)
36
-
37
- def end_session(req: gr.Request):
38
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
39
- shutil.rmtree(user_dir)
40
-
41
- def preprocess_image(image: Image.Image) -> Image.Image:
42
- processed_image = trellis_pipeline.preprocess_image(image)
43
- return processed_image
44
-
45
- def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
46
- return {
47
- 'gaussian': {
48
- **gs.init_params,
49
- '_xyz': gs._xyz.cpu().numpy(),
50
- '_features_dc': gs._features_dc.cpu().numpy(),
51
- '_scaling': gs._scaling.cpu().numpy(),
52
- '_rotation': gs._rotation.cpu().numpy(),
53
- '_opacity': gs._opacity.cpu().numpy(),
54
- },
55
- 'mesh': {
56
- 'vertices': mesh.vertices.cpu().numpy(),
57
- 'faces': mesh.faces.cpu().numpy(),
58
- },
59
- }
60
-
61
- def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
62
- gs = Gaussian(
63
- aabb=state['gaussian']['aabb'],
64
- sh_degree=state['gaussian']['sh_degree'],
65
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
66
- scaling_bias=state['gaussian']['scaling_bias'],
67
- opacity_bias=state['gaussian']['opacity_bias'],
68
- scaling_activation=state['gaussian']['scaling_activation'],
69
- )
70
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
71
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
72
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
73
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
74
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
75
-
76
- mesh = edict(
77
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
78
- faces=torch.tensor(state['mesh']['faces'], device='cuda'),
79
- )
80
-
81
- return gs, mesh
82
-
83
- def get_seed(randomize_seed: bool, seed: int) -> int:
84
- return np.random.randint(0, MAX_SEED) if randomize_seed else seed
85
-
86
- @spaces.GPU
87
- def generate_flux_image(
88
- prompt: str,
89
- seed: int,
90
- randomize_seed: bool,
91
- width: int,
92
- height: int,
93
- guidance_scale: float,
94
- num_inference_steps: int,
95
- lora_scale: float,
96
- progress: gr.Progress = gr.Progress(track_tqdm=True),
97
- ) -> Image.Image:
98
- """Generate image using Flux pipeline"""
99
- if randomize_seed:
100
- seed = random.randint(0, MAX_SEED)
101
- generator = torch.Generator(device=device).manual_seed(seed)
102
-
103
- image = flux_pipeline(
104
- prompt=prompt,
105
- guidance_scale=guidance_scale,
106
- num_inference_steps=num_inference_steps,
107
- width=width,
108
- height=height,
109
- generator=generator,
110
- joint_attention_kwargs={"scale": lora_scale},
111
- ).images[0]
112
-
113
- # Save the generated image
114
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
115
- unique_id = str(uuid.uuid4())[:8]
116
- filename = f"{timestamp}_{unique_id}.png"
117
- filepath = os.path.join(SAVE_DIR, filename)
118
- image.save(filepath)
119
-
120
- return image
121
-
122
- @spaces.GPU
123
- def image_to_3d(
124
- image: Image.Image,
125
- seed: int,
126
- ss_guidance_strength: float,
127
- ss_sampling_steps: int,
128
- slat_guidance_strength: float,
129
- slat_sampling_steps: int,
130
- req: gr.Request,
131
- ) -> Tuple[dict, str]:
132
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
133
- outputs = trellis_pipeline.run(
134
- image,
135
- seed=seed,
136
- formats=["gaussian", "mesh"],
137
- preprocess_image=False,
138
- sparse_structure_sampler_params={
139
- "steps": ss_sampling_steps,
140
- "cfg_strength": ss_guidance_strength,
141
- },
142
- slat_sampler_params={
143
- "steps": slat_sampling_steps,
144
- "cfg_strength": slat_guidance_strength,
145
- },
146
- )
147
- video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
148
- video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
149
- video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
150
- video_path = os.path.join(user_dir, 'sample.mp4')
151
- imageio.mimsave(video_path, video, fps=15)
152
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
153
- torch.cuda.empty_cache()
154
- return state, video_path
155
-
156
- @spaces.GPU(duration=90)
157
- def extract_glb(
158
- state: dict,
159
- mesh_simplify: float,
160
- texture_size: int,
161
- req: gr.Request,
162
- ) -> Tuple[str, str]:
163
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
164
- gs, mesh = unpack_state(state)
165
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
166
- glb_path = os.path.join(user_dir, 'sample.glb')
167
- glb.export(glb_path)
168
- torch.cuda.empty_cache()
169
- return glb_path, glb_path
170
-
171
- @spaces.GPU
172
- def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
173
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
174
- gs, _ = unpack_state(state)
175
- gaussian_path = os.path.join(user_dir, 'sample.ply')
176
- gs.save_ply(gaussian_path)
177
- torch.cuda.empty_cache()
178
- return gaussian_path, gaussian_path
179
-
180
- # Gradio Interface
181
- with gr.Blocks() as demo:
182
- gr.Markdown("""
183
- ## Game Asset Generation to 3D with FLUX and TRELLIS
184
- * Enter a prompt to generate a game asset image, then convert it to 3D
185
- * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
186
- """)
187
-
188
- with gr.Row():
189
- with gr.Column():
190
- # Flux image generation inputs
191
- prompt = gr.Text(label="Prompt", placeholder="Enter your game asset description")
192
- with gr.Accordion("Generation Settings", open=False):
193
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=42, step=1)
194
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
195
- with gr.Row():
196
- width = gr.Slider(256, 1024, label="Width", value=768, step=32)
197
- height = gr.Slider(256, 1024, label="Height", value=768, step=32)
198
- with gr.Row():
199
- guidance_scale = gr.Slider(0.0, 10.0, label="Guidance Scale", value=3.5, step=0.1)
200
- num_inference_steps = gr.Slider(1, 50, label="Steps", value=30, step=1)
201
- lora_scale = gr.Slider(0.0, 1.0, label="LoRA Scale", value=1.0, step=0.1)
202
-
203
- with gr.Accordion("3D Generation Settings", open=False):
204
- gr.Markdown("Stage 1: Sparse Structure Generation")
205
- with gr.Row():
206
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
207
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
208
- gr.Markdown("Stage 2: Structured Latent Generation")
209
- with gr.Row():
210
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
211
- slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
212
-
213
- generate_btn = gr.Button("Generate")
214
-
215
- with gr.Accordion("GLB Extraction Settings", open=False):
216
- mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
217
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
218
-
219
- with gr.Row():
220
- extract_glb_btn = gr.Button("Extract GLB", interactive=False)
221
- extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
222
-
223
- with gr.Column():
224
- generated_image = gr.Image(label="Generated Asset", type="pil")
225
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True)
226
- model_output = LitModel3D(label="Extracted GLB/Gaussian")
227
-
228
- with gr.Row():
229
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
230
- download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
231
-
232
- output_buf = gr.State()
233
-
234
- # Event handlers
235
- demo.load(start_session)
236
- demo.unload(end_session)
237
-
238
- generate_btn.click(
239
- generate_flux_image,
240
- inputs=[prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, lora_scale],
241
- outputs=[generated_image],
242
- ).then(
243
- image_to_3d,
244
- inputs=[generated_image, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
245
- outputs=[output_buf, video_output],
246
- ).then(
247
- lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
248
- outputs=[extract_glb_btn, extract_gs_btn],
249
- )
250
-
251
- extract_glb_btn.click(
252
- extract_glb,
253
- inputs=[output_buf, mesh_simplify, texture_size],
254
- outputs=[model_output, download_glb],
255
- ).then(
256
- lambda: gr.Button(interactive=True),
257
- outputs=[download_glb],
258
- )
259
-
260
- extract_gs_btn.click(
261
- extract_gaussian,
262
- inputs=[output_buf],
263
- outputs=[model_output, download_gs],
264
- ).then(
265
- lambda: gr.Button(interactive=True),
266
- outputs=[download_gs],
267
- )
268
-
269
- model_output.clear(
270
- lambda: gr.Button(interactive=False),
271
- outputs=[download_glb],
272
- )
273
-
274
- # Initialize both pipelines
275
- if __name__ == "__main__":
276
- from diffusers import FluxTransformer2DModel, FluxPipeline, BitsAndBytesConfig, BitsAndBytesConfigTF
277
- from transformers import T5EncoderModel
278
-
279
- # Initialize Flux pipeline
280
- device = "cuda" if torch.cuda.is_available() else "cpu"
281
- huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
282
-
283
- dtype = torch.bfloat16
284
- file_url = "https://huggingface.co/gokaygokay/flux-game/blob/main/gokaygokay_00001_.safetensors"
285
- single_file_base_model = "camenduru/FLUX.1-dev-diffusers"
286
- quantization_config_tf = BitsAndBytesConfigTF(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
287
- text_encoder_2 = T5EncoderModel.from_pretrained(single_file_base_model, subfolder="text_encoder_2", torch_dtype=dtype, config=single_file_base_model, quantization_config=quantization_config_tf, token=huggingface_token)
288
- quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
289
- transformer = FluxTransformer2DModel.from_single_file(file_url, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model, quantization_config=quantization_config, token=huggingface_token)
290
- flux_pipeline = FluxPipeline.from_pretrained(single_file_base_model, transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=dtype, quantization_config=quantization_config, token=huggingface_token)
291
-
292
- # Initialize Trellis pipeline
293
- trellis_pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
294
- trellis_pipeline.cuda()
295
- try:
296
- trellis_pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
297
- except:
298
- pass
299
-
300
  demo.launch()
 
1
+ import gradio as gr
2
+ import spaces
3
+ from gradio_litmodel3d import LitModel3D
4
+ import os
5
+ import shutil
6
+ import random
7
+ import uuid
8
+ from datetime import datetime
9
+ from diffusers import DiffusionPipeline
10
+
11
+ os.environ['SPCONV_ALGO'] = 'native'
12
+ from typing import *
13
+ import torch
14
+ import numpy as np
15
+ import imageio
16
+ from easydict import EasyDict as edict
17
+ from PIL import Image
18
+ from trellis.pipelines import TrellisImageTo3DPipeline
19
+ from trellis.representations import Gaussian, MeshExtractResult
20
+ from trellis.utils import render_utils, postprocessing_utils
21
+
22
+ huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
23
+ # Constants
24
+ MAX_SEED = np.iinfo(np.int32).max
25
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
26
+ os.makedirs(TMP_DIR, exist_ok=True)
27
+
28
+ # Create permanent storage directory for Flux generated images
29
+ SAVE_DIR = "saved_images"
30
+ if not os.path.exists(SAVE_DIR):
31
+ os.makedirs(SAVE_DIR, exist_ok=True)
32
+
33
+ def start_session(req: gr.Request):
34
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
35
+ os.makedirs(user_dir, exist_ok=True)
36
+
37
+ def end_session(req: gr.Request):
38
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
39
+ shutil.rmtree(user_dir)
40
+
41
+ def preprocess_image(image: Image.Image) -> Image.Image:
42
+ processed_image = trellis_pipeline.preprocess_image(image)
43
+ return processed_image
44
+
45
+ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
46
+ return {
47
+ 'gaussian': {
48
+ **gs.init_params,
49
+ '_xyz': gs._xyz.cpu().numpy(),
50
+ '_features_dc': gs._features_dc.cpu().numpy(),
51
+ '_scaling': gs._scaling.cpu().numpy(),
52
+ '_rotation': gs._rotation.cpu().numpy(),
53
+ '_opacity': gs._opacity.cpu().numpy(),
54
+ },
55
+ 'mesh': {
56
+ 'vertices': mesh.vertices.cpu().numpy(),
57
+ 'faces': mesh.faces.cpu().numpy(),
58
+ },
59
+ }
60
+
61
+ def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
62
+ gs = Gaussian(
63
+ aabb=state['gaussian']['aabb'],
64
+ sh_degree=state['gaussian']['sh_degree'],
65
+ mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
66
+ scaling_bias=state['gaussian']['scaling_bias'],
67
+ opacity_bias=state['gaussian']['opacity_bias'],
68
+ scaling_activation=state['gaussian']['scaling_activation'],
69
+ )
70
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
71
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
72
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
73
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
74
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
75
+
76
+ mesh = edict(
77
+ vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
78
+ faces=torch.tensor(state['mesh']['faces'], device='cuda'),
79
+ )
80
+
81
+ return gs, mesh
82
+
83
+ def get_seed(randomize_seed: bool, seed: int) -> int:
84
+ return np.random.randint(0, MAX_SEED) if randomize_seed else seed
85
+
86
+ @spaces.GPU
87
+ def generate_flux_image(
88
+ prompt: str,
89
+ seed: int,
90
+ randomize_seed: bool,
91
+ width: int,
92
+ height: int,
93
+ guidance_scale: float,
94
+ num_inference_steps: int,
95
+ lora_scale: float,
96
+ progress: gr.Progress = gr.Progress(track_tqdm=True),
97
+ ) -> Image.Image:
98
+ """Generate image using Flux pipeline"""
99
+ if randomize_seed:
100
+ seed = random.randint(0, MAX_SEED)
101
+ generator = torch.Generator(device=device).manual_seed(seed)
102
+
103
+ image = flux_pipeline(
104
+ prompt=prompt,
105
+ guidance_scale=guidance_scale,
106
+ num_inference_steps=num_inference_steps,
107
+ width=width,
108
+ height=height,
109
+ generator=generator,
110
+ joint_attention_kwargs={"scale": lora_scale},
111
+ ).images[0]
112
+
113
+ # Save the generated image
114
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
115
+ unique_id = str(uuid.uuid4())[:8]
116
+ filename = f"{timestamp}_{unique_id}.png"
117
+ filepath = os.path.join(SAVE_DIR, filename)
118
+ image.save(filepath)
119
+
120
+ return image
121
+
122
+ @spaces.GPU
123
+ def image_to_3d(
124
+ image: Image.Image,
125
+ seed: int,
126
+ ss_guidance_strength: float,
127
+ ss_sampling_steps: int,
128
+ slat_guidance_strength: float,
129
+ slat_sampling_steps: int,
130
+ req: gr.Request,
131
+ ) -> Tuple[dict, str]:
132
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
133
+ outputs = trellis_pipeline.run(
134
+ image,
135
+ seed=seed,
136
+ formats=["gaussian", "mesh"],
137
+ preprocess_image=False,
138
+ sparse_structure_sampler_params={
139
+ "steps": ss_sampling_steps,
140
+ "cfg_strength": ss_guidance_strength,
141
+ },
142
+ slat_sampler_params={
143
+ "steps": slat_sampling_steps,
144
+ "cfg_strength": slat_guidance_strength,
145
+ },
146
+ )
147
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
148
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
149
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
150
+ video_path = os.path.join(user_dir, 'sample.mp4')
151
+ imageio.mimsave(video_path, video, fps=15)
152
+ state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
153
+ torch.cuda.empty_cache()
154
+ return state, video_path
155
+
156
+ @spaces.GPU(duration=90)
157
+ def extract_glb(
158
+ state: dict,
159
+ mesh_simplify: float,
160
+ texture_size: int,
161
+ req: gr.Request,
162
+ ) -> Tuple[str, str]:
163
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
164
+ gs, mesh = unpack_state(state)
165
+ glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
166
+ glb_path = os.path.join(user_dir, 'sample.glb')
167
+ glb.export(glb_path)
168
+ torch.cuda.empty_cache()
169
+ return glb_path, glb_path
170
+
171
+ @spaces.GPU
172
+ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
173
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
174
+ gs, _ = unpack_state(state)
175
+ gaussian_path = os.path.join(user_dir, 'sample.ply')
176
+ gs.save_ply(gaussian_path)
177
+ torch.cuda.empty_cache()
178
+ return gaussian_path, gaussian_path
179
+
180
+ # Gradio Interface
181
+ with gr.Blocks() as demo:
182
+ gr.Markdown("""
183
+ ## Game Asset Generation to 3D with FLUX and TRELLIS
184
+ * Enter a prompt to generate a game asset image, then convert it to 3D
185
+ * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
186
+ """)
187
+
188
+ with gr.Row():
189
+ with gr.Column():
190
+ # Flux image generation inputs
191
+ prompt = gr.Text(label="Prompt", placeholder="Enter your game asset description")
192
+ with gr.Accordion("Generation Settings", open=False):
193
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=42, step=1)
194
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
195
+ with gr.Row():
196
+ width = gr.Slider(256, 1024, label="Width", value=768, step=32)
197
+ height = gr.Slider(256, 1024, label="Height", value=768, step=32)
198
+ with gr.Row():
199
+ guidance_scale = gr.Slider(0.0, 10.0, label="Guidance Scale", value=3.5, step=0.1)
200
+ num_inference_steps = gr.Slider(1, 50, label="Steps", value=30, step=1)
201
+ lora_scale = gr.Slider(0.0, 1.0, label="LoRA Scale", value=1.0, step=0.1)
202
+
203
+ with gr.Accordion("3D Generation Settings", open=False):
204
+ gr.Markdown("Stage 1: Sparse Structure Generation")
205
+ with gr.Row():
206
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
207
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
208
+ gr.Markdown("Stage 2: Structured Latent Generation")
209
+ with gr.Row():
210
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
211
+ slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
212
+
213
+ generate_btn = gr.Button("Generate")
214
+
215
+ with gr.Accordion("GLB Extraction Settings", open=False):
216
+ mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
217
+ texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
218
+
219
+ with gr.Row():
220
+ extract_glb_btn = gr.Button("Extract GLB", interactive=False)
221
+ extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
222
+
223
+ with gr.Column():
224
+ generated_image = gr.Image(label="Generated Asset", type="pil")
225
+ video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True)
226
+ model_output = LitModel3D(label="Extracted GLB/Gaussian")
227
+
228
+ with gr.Row():
229
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
230
+ download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
231
+
232
+ output_buf = gr.State()
233
+
234
+ # Event handlers
235
+ demo.load(start_session)
236
+ demo.unload(end_session)
237
+
238
+ generate_btn.click(
239
+ generate_flux_image,
240
+ inputs=[prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, lora_scale],
241
+ outputs=[generated_image],
242
+ ).then(
243
+ image_to_3d,
244
+ inputs=[generated_image, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
245
+ outputs=[output_buf, video_output],
246
+ ).then(
247
+ lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
248
+ outputs=[extract_glb_btn, extract_gs_btn],
249
+ )
250
+
251
+ extract_glb_btn.click(
252
+ extract_glb,
253
+ inputs=[output_buf, mesh_simplify, texture_size],
254
+ outputs=[model_output, download_glb],
255
+ ).then(
256
+ lambda: gr.Button(interactive=True),
257
+ outputs=[download_glb],
258
+ )
259
+
260
+ extract_gs_btn.click(
261
+ extract_gaussian,
262
+ inputs=[output_buf],
263
+ outputs=[model_output, download_gs],
264
+ ).then(
265
+ lambda: gr.Button(interactive=True),
266
+ outputs=[download_gs],
267
+ )
268
+
269
+ model_output.clear(
270
+ lambda: gr.Button(interactive=False),
271
+ outputs=[download_glb],
272
+ )
273
+
274
+ # Initialize both pipelines
275
+ if __name__ == "__main__":
276
+ from diffusers import FluxTransformer2DModel, FluxPipeline, BitsAndBytesConfig
277
+ from transformers import T5EncoderModel, BitsAndBytesConfig as BitsAndBytesConfigTF
278
+
279
+ # Initialize Flux pipeline
280
+ device = "cuda" if torch.cuda.is_available() else "cpu"
281
+ huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
282
+
283
+ dtype = torch.bfloat16
284
+ file_url = "https://huggingface.co/gokaygokay/flux-game/blob/main/gokaygokay_00001_.safetensors"
285
+ single_file_base_model = "camenduru/FLUX.1-dev-diffusers"
286
+ quantization_config_tf = BitsAndBytesConfigTF(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
287
+ text_encoder_2 = T5EncoderModel.from_pretrained(single_file_base_model, subfolder="text_encoder_2", torch_dtype=dtype, config=single_file_base_model, quantization_config=quantization_config_tf, token=huggingface_token)
288
+ quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
289
+ transformer = FluxTransformer2DModel.from_single_file(file_url, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model, quantization_config=quantization_config, token=huggingface_token)
290
+ flux_pipeline = FluxPipeline.from_pretrained(single_file_base_model, transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=dtype, quantization_config=quantization_config, token=huggingface_token)
291
+
292
+ # Initialize Trellis pipeline
293
+ trellis_pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
294
+ trellis_pipeline.cuda()
295
+ try:
296
+ trellis_pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
297
+ except:
298
+ pass
299
+
300
  demo.launch()