aiqtech committed
Commit 1f5cf77 · verified · 1 Parent(s): 9f57959

Update app.py

Files changed (1)
  1. app.py +42 -72
app.py CHANGED
@@ -21,15 +21,10 @@ TMP_DIR = "/tmp/Trellis-demo"
 os.makedirs(TMP_DIR, exist_ok=True)
 
 
-# Memory-related environment variables
-os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:32'
-os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
-os.environ['TORCH_HOME'] = '/tmp/torch_home'
-os.environ['HF_HOME'] = '/tmp/huggingface'
-os.environ['XDG_CACHE_HOME'] = '/tmp/cache'
-os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
-os.environ['SPCONV_ALGO'] = 'native'
-os.environ['WARP_USE_CPU'] = '1'
+# GPU memory-related environment variables (updated)
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'  # increased for the A100
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # use a single GPU
+os.environ['CUDA_LAUNCH_BLOCKING'] = '0'  # allow asynchronous execution on the A100
 
 def initialize_models():
     global pipeline, translator, flux_pipe
@@ -37,33 +32,27 @@ def initialize_models():
     try:
         import torch
 
-        # Memory settings
-        torch.backends.cudnn.benchmark = False
-        torch.backends.cudnn.deterministic = True
+        # A100 optimization settings
+        torch.backends.cudnn.benchmark = True  # enabled for better performance on the A100
+        torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32
+        torch.backends.cudnn.allow_tf32 = True
 
         print("Initializing Trellis pipeline...")
-        # Initialize the Trellis pipeline
         pipeline = TrellisImageTo3DPipeline.from_pretrained(
-            "JeffreyXiang/TRELLIS-image-large"
+            "JeffreyXiang/TRELLIS-image-large",
+            torch_dtype=torch.float16  # use FP16 on the A100
         )
 
-        if pipeline is None:
-            raise Exception("Failed to initialize Trellis pipeline")
+        if torch.cuda.is_available():
+            pipeline = pipeline.to("cuda")
 
         print("Initializing translator...")
-        # Initialize the translator
         translator = translation_pipeline(
             "translation",
             model="Helsinki-NLP/opus-mt-ko-en",
-            device="cpu"
+            device="cuda"  # run the translator on the GPU as well
         )
 
-        if translator is None:
-            raise Exception("Failed to initialize translator")
-
-        # The Flux pipeline is initialized later
-        flux_pipe = None
-
         print("Models initialized successfully")
         return True
 
@@ -79,15 +68,17 @@ def get_flux_pipe():
         free_memory()
         flux_pipe = FluxPipeline.from_pretrained(
             "black-forest-labs/FLUX.1-dev",
-            torch_dtype=torch.float32,  # start in CPU mode
+            torch_dtype=torch.float16,  # use FP16 on the A100
             use_safetensors=True
-        )
+        ).to("cuda")
     except Exception as e:
         print(f"Error loading Flux pipeline: {e}")
         return None
     return flux_pipe
 
 
+
+
 def free_memory():
     """Enhanced memory cleanup function"""
     import gc
@@ -265,7 +256,7 @@ def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_stre
 @spaces.GPU
 def generate_image_from_text(prompt, height, width, guidance_scale, num_steps):
     try:
-        free_memory()  # free memory before starting
+        free_memory()
 
         # Get the Flux pipeline
         flux_pipe = get_flux_pipe()
@@ -273,25 +264,27 @@ def generate_image_from_text(prompt, height, width, guidance_scale, num_steps):
             raise Exception("Failed to load Flux pipeline")
 
         # Limit image size
-        height = min(height, 512)
-        width = min(width, 512)
+        height = min(height, 1024)  # allow larger images on the A100
+        width = min(width, 1024)
 
         # Process the prompt
        base_prompt = "wbgmsst, 3D, white background"
         translated_prompt = translate_if_korean(prompt)
         final_prompt = f"{translated_prompt}, {base_prompt}"
 
-        with torch.inference_mode(), torch.cuda.amp.autocast():
+        with torch.cuda.amp.autocast():  # automatic mixed precision on the A100
             output = flux_pipe(
                 prompt=[final_prompt],
                 height=height,
                 width=width,
-                guidance_scale=min(guidance_scale, 7.5),  # capped at a lower value
-                num_inference_steps=min(num_steps, 20)  # limit the step count
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_steps,
+                generator=torch.Generator(device='cuda')
             )
-            image = output.images[0]
+
+        image = output.images[0]
 
-        free_memory()  # free memory after completion
+        free_memory()
         return image
 
     except Exception as e:
@@ -444,50 +437,27 @@ if __name__ == "__main__":
     import warnings
     warnings.filterwarnings('ignore')
 
+    # Check the CUDA configuration
+    if torch.cuda.is_available():
+        print(f"Using GPU: {torch.cuda.get_device_name()}")
+        print(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
+
     # Create the directory
     os.makedirs(TMP_DIR, exist_ok=True)
 
     # Free memory
     free_memory()
 
-    # Attempt model initialization
-    retry_count = 3
-    initialized = False
-
-    for i in range(retry_count):
-        try:
-            if initialize_models():
-                initialized = True
-                break
-            else:
-                print(f"Initialization attempt {i+1} failed, retrying...")
-                free_memory()
-        except Exception as e:
-            print(f"Error during initialization attempt {i+1}: {str(e)}")
-            free_memory()
-
-    if not initialized:
-        print("Failed to initialize models after multiple attempts")
+    # Initialize models
+    if not initialize_models():
+        print("Failed to initialize models")
         exit(1)
 
-    try:
-        # Try preloading rembg
-        test_image = Image.fromarray(np.ones((32, 32, 3), dtype=np.uint8) * 255)
-        if pipeline is not None:
-            pipeline.preprocess_image(test_image)
-    except Exception as e:
-        print(f"Warning: Failed to preload rembg: {str(e)}")
-
     # Launch the Gradio app
-    try:
-        demo.queue(max_size=1).launch(
-            share=True,
-            max_threads=1,
-            show_error=True,
-            server_port=7860,
-            server_name="0.0.0.0",
-            quiet=True
-        )
-    except Exception as e:
-        print(f"Error launching Gradio app: {str(e)}")
-        exit(1)
+    demo.queue(max_size=2).launch(  # increased queue size
+        share=True,
+        max_threads=4,  # increased thread count
+        show_error=True,
+        server_port=7860,
+        server_name="0.0.0.0"
+    )
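
A note on the allocator change in the first hunk: PYTORCH_CUDA_ALLOC_CONF is only read when PyTorch initializes its CUDA caching allocator, so it must be in the environment before the first CUDA call, which is why app.py sets it at the top of the module. A minimal sketch of the safe ordering (the tensor and print at the end are illustrative, not app code):

import os

# Allocator options are read when PyTorch first initializes CUDA,
# so they must be set before any CUDA work happens.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

import torch  # imported after the env var is set

if torch.cuda.is_available():
    x = torch.zeros(1, device='cuda')  # allocator initializes on first CUDA use
    print(torch.cuda.memory_allocated())  # nonzero once the tensor is allocated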
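The TF32 switches enabled in initialize_models() are process-wide flags rather than per-model settings. A self-contained sketch of their effect on an Ampere-class GPU such as the A100 (the matmul below is a stand-in, not app code):

import torch

# TF32 keeps FP32 dynamic range but rounds the mantissa to 10 bits in tensor cores,
# speeding up matmuls and convolutions on Ampere GPUs at a small precision cost.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True  # let cuDNN autotune kernels for repeated shapes

if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device='cuda')
    b = a @ a  # routed through TF32 tensor cores while the flag is on
    print(b.dtype)  # still torch.float32; TF32 is an internal compute format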
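Since both pipelines are now loaded with torch_dtype=torch.float16, the autocast context in generate_image_from_text acts mainly as a safety net for any FP32 submodules. For reference, a sketch of the equivalent pattern; recent PyTorch releases prefer torch.amp.autocast over the torch.cuda.amp.autocast spelling, and pairing it with inference_mode avoids autograd bookkeeping during generation:

import torch

if torch.cuda.is_available():
    w = torch.randn(512, 512, device='cuda')
    with torch.inference_mode(), torch.amp.autocast('cuda', dtype=torch.float16):
        y = w @ w  # the matmul runs in FP16 inside the context
    print(y.dtype)  # torch.float16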
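Finally, a runnable sketch of the queue and launch configuration used at the end of app.py; the one-line Blocks UI is a hypothetical stand-in for the real interface:

import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("placeholder UI")  # stand-in for the actual app layout

# queue(max_size=2) bounds the number of pending requests;
# max_threads=4 raises how many requests Gradio processes concurrently.
demo.queue(max_size=2).launch(
    server_name="0.0.0.0",  # listen on all interfaces, as inside a Space container
    server_port=7860,
    max_threads=4,
    show_error=True,
)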