aiqtech committed on
Commit
5201a38
·
verified ·
1 Parent(s): 9df34dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -56
app.py CHANGED
@@ -23,45 +23,39 @@ os.makedirs(TMP_DIR, exist_ok=True)
23
  def initialize_models():
24
  global pipeline, translator, flux_pipe
25
 
26
- # CUDA ๋ฉ”๋ชจ๋ฆฌ ์ดˆ๊ธฐํ™”
27
- if torch.cuda.is_available():
28
- torch.cuda.empty_cache()
29
-
30
  try:
 
 
 
 
 
 
31
  # Trellis ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™”
32
- pipeline = TrellisImageTo3DPipeline.from_pretrained(
33
- "JeffreyXiang/TRELLIS-image-large",
34
- device_map="auto" # Zero GPU ํ™˜๊ฒฝ์— ๋งž๊ฒŒ ์ž๋™ device ๋งคํ•‘
35
- )
36
 
37
- # ๋ฒˆ์—ญ๊ธฐ ์ดˆ๊ธฐํ™”
38
  translator = translation_pipeline(
39
  "translation",
40
  model="Helsinki-NLP/opus-mt-ko-en",
41
- device_map="auto"
42
  )
43
 
44
- # Flux ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™”
45
  flux_pipe = FluxPipeline.from_pretrained(
46
  "black-forest-labs/FLUX.1-dev",
47
- torch_dtype=torch.float16, # bfloat16 ๋Œ€์‹  float16 ์‚ฌ์šฉ
48
- device_map="auto"
49
  )
50
 
51
- # LoRA ๊ฐ€์ค‘์น˜ ๋กœ๋“œ
52
- flux_pipe.load_lora_weights(
53
- "gokaygokay/Flux-Game-Assets-LoRA-v2",
54
- device_map="auto"
55
- )
56
- flux_pipe.fuse_lora(lora_scale=1.0)
57
 
58
  except Exception as e:
59
- print(f"Error initializing models: {str(e)}")
60
- if torch.cuda.is_available():
61
- torch.cuda.empty_cache()
62
- raise e
63
-
64
-
65
 
66
  def translate_if_korean(text):
67
  if any(ord('๊ฐ€') <= ord(char) <= ord('ํžฃ') for char in text):
@@ -116,33 +110,32 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
116
  return gs, mesh, state['trial_id']
117
 
118
  @spaces.GPU
119
- def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float,
120
  ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int):
121
  try:
122
- if torch.cuda.is_available():
123
- torch.cuda.empty_cache()
124
-
125
  if randomize_seed:
126
  seed = np.random.randint(0, MAX_SEED)
127
 
128
  input_image = Image.open(f"{TMP_DIR}/{trial_id}.png")
129
 
130
- # ๋ฉ”๋ชจ๋ฆฌ ์ตœ์ ํ™”๋ฅผ ์œ„ํ•œ ์ปจํ…์ŠคํŠธ ๋งค๋‹ˆ์ € ์‚ฌ์šฉ
131
- with torch.cuda.amp.autocast(enabled=True):
132
- outputs = pipeline.run(
133
- input_image,
134
- seed=seed,
135
- formats=["gaussian", "mesh"],
136
- preprocess_image=False,
137
- sparse_structure_sampler_params={
138
- "steps": ss_sampling_steps,
139
- "cfg_strength": ss_guidance_strength,
140
- },
141
- slat_sampler_params={
142
- "steps": slat_sampling_steps,
143
- "cfg_strength": slat_guidance_strength,
144
- }
145
- )
146
 
147
  # ๋น„๋””์˜ค ๋ Œ๋”๋ง
148
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
@@ -163,8 +156,7 @@ def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_stre
163
 
164
  except Exception as e:
165
  print(f"Error in image_to_3d: {str(e)}")
166
- if torch.cuda.is_available():
167
- torch.cuda.empty_cache()
168
  raise e
169
 
170
  @spaces.GPU
@@ -334,17 +326,25 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
334
  )
335
 
336
  if __name__ == "__main__":
337
- # CUDA ๋ฉ”๋ชจ๋ฆฌ ์บ์‹œ ์ดˆ๊ธฐํ™”
338
- torch.cuda.empty_cache()
 
339
 
340
- # ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
341
- initialize_models()
 
 
342
 
343
  try:
344
- # rembg ์‚ฌ์ „ ๋กœ๋“œ
345
- pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
346
- except:
347
- pass
 
348
 
349
  # Gradio ์•ฑ ์‹คํ–‰
350
- demo.launch(share=True) # share=True ์ถ”๊ฐ€
 
 
 
 
 
23
def initialize_models():
    """Load the Trellis, translation, and Flux pipelines into module globals.

    Populates the module-level ``pipeline``, ``translator`` and ``flux_pipe``
    globals so the Gradio handlers can use them.

    Returns:
        bool: True when every model loaded successfully, False otherwise
        (the caller is expected to abort startup on False).
    """
    global pipeline, translator, flux_pipe

    try:
        # Clear stale GPU allocations before loading large models.
        # Guarded: only meaningful (and only guaranteed safe on every
        # torch build) when CUDA is actually available.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Decide the execution device once and reuse it for every pipeline.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Trellis image-to-3D pipeline.
        pipeline = TrellisImageTo3DPipeline.from_pretrained(
            "JeffreyXiang/TRELLIS-image-large"
        )
        pipeline.to(device)

        # Korean -> English translator. transformers' pipeline API takes an
        # integer device index: 0 = first GPU, -1 = CPU.
        translator = translation_pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-ko-en",
            device=0 if device == "cuda" else -1,
        )

        # Flux text-to-image pipeline; fp16 only makes sense on GPU.
        flux_pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        )

        if device == "cuda":
            # Stream submodules between CPU and GPU to fit limited VRAM.
            flux_pipe.enable_model_cpu_offload()

        return True

    except Exception as e:
        print(f"Model initialization error: {str(e)}")
        # Best-effort cleanup of any partially-allocated GPU memory.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return False
 
 
 
59
 
60
  def translate_if_korean(text):
61
  if any(ord('๊ฐ€') <= ord(char) <= ord('ํžฃ') for char in text):
 
110
  return gs, mesh, state['trial_id']
111
 
112
  @spaces.GPU
113
+ def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_strength: float,
114
  ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int):
115
  try:
116
+ torch.cuda.empty_cache()
117
+
 
118
  if randomize_seed:
119
  seed = np.random.randint(0, MAX_SEED)
120
 
121
  input_image = Image.open(f"{TMP_DIR}/{trial_id}.png")
122
 
123
+ with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
124
+ with torch.no_grad():
125
+ outputs = pipeline.run(
126
+ input_image,
127
+ seed=seed,
128
+ formats=["gaussian", "mesh"],
129
+ preprocess_image=False,
130
+ sparse_structure_sampler_params={
131
+ "steps": ss_sampling_steps,
132
+ "cfg_strength": ss_guidance_strength,
133
+ },
134
+ slat_sampler_params={
135
+ "steps": slat_sampling_steps,
136
+ "cfg_strength": slat_guidance_strength,
137
+ }
138
+ )
139
 
140
  # ๋น„๋””์˜ค ๋ Œ๋”๋ง
141
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
 
156
 
157
  except Exception as e:
158
  print(f"Error in image_to_3d: {str(e)}")
159
+ torch.cuda.empty_cache()
 
160
  raise e
161
 
162
  @spaces.GPU
 
326
  )
327
 
328
if __name__ == "__main__":
    # Free any leftover GPU memory before loading the models.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Abort startup if any model failed to load (initialize_models returns
    # False instead of raising).
    if not initialize_models():
        print("Failed to initialize models")
        # raise SystemExit rather than exit(): the `exit` helper comes from
        # the `site` module and is absent under `python -S`.
        raise SystemExit(1)

    try:
        # Warm up rembg so the first user request doesn't pay the model-load
        # cost; failures here are non-fatal.
        test_image = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
        pipeline.preprocess_image(test_image)
    except Exception as e:
        print(f"Warning: Failed to preload rembg: {str(e)}")

    # Launch the Gradio app. queue() already enables queueing, so the
    # deprecated (removed in Gradio 4.x) `enable_queue` launch kwarg is
    # dropped; behavior is unchanged on 3.x.
    demo.queue(concurrency_count=1).launch(
        share=True,
        max_threads=1,
    )