Update app.py

app.py CHANGED
@@ -21,15 +21,10 @@ TMP_DIR = "/tmp/Trellis-demo"
 os.makedirs(TMP_DIR, exist_ok=True)


-# Memory-related environment variables
-os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:
-os.environ['
-os.environ['
-os.environ['HF_HOME'] = '/tmp/huggingface'
-os.environ['XDG_CACHE_HOME'] = '/tmp/cache'
-os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
-os.environ['SPCONV_ALGO'] = 'native'
-os.environ['WARP_USE_CPU'] = '1'
+# GPU memory environment variables (revised)
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'  # increased to suit the A100
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # use a single GPU
+os.environ['CUDA_LAUNCH_BLOCKING'] = '0'  # allow asynchronous kernel launches on the A100

 def initialize_models():
     global pipeline, translator, flux_pipe
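Note on ordering: PyTorch reads PYTORCH_CUDA_ALLOC_CONF once, when the CUDA caching allocator is first initialized, so these assignments only take effect if they run before the first GPU allocation. A minimal standalone sketch of the safe ordering (not taken from app.py):

import os

# Must be set before the first CUDA allocation, ideally before importing torch.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

import torch

if torch.cuda.is_available():
    x = torch.zeros(1, device='cuda')  # first allocation; the allocator config is now fixed
    print(torch.cuda.memory_allocated())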
@@ -37,33 +32,27 @@ def initialize_models():
     try:
         import torch

-        #
-        torch.backends.cudnn.benchmark =
-        torch.backends.
+        # A100 optimization settings
+        torch.backends.cudnn.benchmark = True  # enabled for better performance on the A100
+        torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32
+        torch.backends.cudnn.allow_tf32 = True

         print("Initializing Trellis pipeline...")
-        # Initialize the Trellis pipeline
         pipeline = TrellisImageTo3DPipeline.from_pretrained(
-            "JeffreyXiang/TRELLIS-image-large"
+            "JeffreyXiang/TRELLIS-image-large",
+            torch_dtype=torch.float16  # use FP16 on the A100
         )

-        if
-
+        if torch.cuda.is_available():
+            pipeline = pipeline.to("cuda")

         print("Initializing translator...")
-        # Initialize the translator
         translator = translation_pipeline(
             "translation",
             model="Helsinki-NLP/opus-mt-ko-en",
-            device="
+            device="cuda"  # run the translator on the GPU as well
         )

-        if translator is None:
-            raise Exception("Failed to initialize translator")
-
-        # Initialize the Flux pipeline later
-        flux_pipe = None
-
         print("Models initialized successfully")
         return True

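The three backend switches trade a little precision for speed on Ampere-class GPUs such as the A100. A standalone sketch of what they control, runnable outside the app:

import torch

torch.backends.cuda.matmul.allow_tf32 = True  # float32 matmuls may run on TF32 tensor cores
torch.backends.cudnn.allow_tf32 = True        # cuDNN convolutions may use TF32 as well
torch.backends.cudnn.benchmark = True         # autotune conv algorithms; helps when input shapes are fixed

if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device='cuda')
    b = torch.randn(1024, 1024, device='cuda')
    c = a @ b  # eligible for TF32 tensor cores under the settings above
    print(c.dtype)  # still torch.float32; only the internal math precision changes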
@@ -79,15 +68,17 @@ def get_flux_pipe():
         free_memory()
         flux_pipe = FluxPipeline.from_pretrained(
             "black-forest-labs/FLUX.1-dev",
-            torch_dtype=torch.
+            torch_dtype=torch.float16,  # use FP16 on the A100
             use_safetensors=True
-        )
+        ).to("cuda")
     except Exception as e:
         print(f"Error loading Flux pipeline: {e}")
         return None
     return flux_pipe


+
+
 def free_memory():
     """Enhanced memory cleanup function"""
     import gc
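free_memory() is called throughout but its body falls outside these hunks. A plausible sketch of an "enhanced" cleanup helper, built only from standard gc and torch calls (an assumption, not the app's actual implementation):

import gc

def free_memory():
    """Enhanced memory cleanup: collect garbage, then release cached CUDA blocks."""
    import torch
    gc.collect()                  # drop unreachable Python objects first
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached allocator blocks to the driver
        torch.cuda.ipc_collect()  # reclaim memory held by dead CUDA IPC handles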
@@ -265,7 +256,7 @@ def image_to_3d(trial_id: str, seed: int, randomize_seed: bool, ss_guidance_stre
 @spaces.GPU
 def generate_image_from_text(prompt, height, width, guidance_scale, num_steps):
     try:
-        free_memory()
+        free_memory()

         # Get the Flux pipeline
         flux_pipe = get_flux_pipe()
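get_flux_pipe() implements a lazy singleton: FLUX.1-dev is loaded on first call and cached in a module-level global, so repeated generations reuse one copy of the weights. A generic sketch of the pattern (load_model is a hypothetical stand-in for the FluxPipeline.from_pretrained call above):

_model = None

def load_model():
    # hypothetical stand-in for an expensive FluxPipeline.from_pretrained(...) call
    return object()

def get_model():
    """Load on first use, then reuse the cached instance."""
    global _model
    if _model is None:
        _model = load_model()
    return _model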
@@ -273,25 +264,27 @@ def generate_image_from_text(prompt, height, width, guidance_scale, num_steps):
             raise Exception("Failed to load Flux pipeline")

         # Limit the image size
-        height = min(height,
-        width = min(width,
+        height = min(height, 1024)  # allow larger images on the A100
+        width = min(width, 1024)

         # Process the prompt
         base_prompt = "wbgmsst, 3D, white background"
         translated_prompt = translate_if_korean(prompt)
         final_prompt = f"{translated_prompt}, {base_prompt}"

-        with torch.
+        with torch.cuda.amp.autocast():  # use automatic mixed precision on the A100
             output = flux_pipe(
                 prompt=[final_prompt],
                 height=height,
                 width=width,
-                guidance_scale=
-                num_inference_steps=
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_steps,
+                generator=torch.Generator(device='cuda')
             )
-
+
+        image = output.images[0]

-        free_memory()
+        free_memory()
         return image

     except Exception as e:
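Two details worth noting in the new generation path: torch.cuda.amp.autocast() is the older spelling (recent PyTorch releases prefer torch.amp.autocast('cuda')), and the torch.Generator is created without a seed, so sampling stays nondeterministic. A sketch of a seeded variant, if reproducible outputs are wanted (the helper name and seed value are illustrative):

import torch

def make_generator(seed=None):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    gen = torch.Generator(device=device)
    if seed is not None:
        gen.manual_seed(seed)  # a fixed seed makes the sampling noise repeatable
    return gen

gen = make_generator(seed=42)
noise = torch.randn(4, generator=gen, device=gen.device)
print(noise)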
@@ -444,50 +437,27 @@ if __name__ == "__main__":
     import warnings
     warnings.filterwarnings('ignore')

+    # Check the CUDA configuration
+    if torch.cuda.is_available():
+        print(f"Using GPU: {torch.cuda.get_device_name()}")
+        print(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
+
     # Create the directory
     os.makedirs(TMP_DIR, exist_ok=True)

     # Clean up memory
     free_memory()

-    # Initialize the models
-
-
-
-    for i in range(retry_count):
-        try:
-            if initialize_models():
-                initialized = True
-                break
-            else:
-                print(f"Initialization attempt {i+1} failed, retrying...")
-                free_memory()
-        except Exception as e:
-            print(f"Error during initialization attempt {i+1}: {str(e)}")
-            free_memory()
-
-    if not initialized:
-        print("Failed to initialize models after multiple attempts")
+    # Initialize the models
+    if not initialize_models():
+        print("Failed to initialize models")
         exit(1)

-    try:
-        # Try to preload rembg
-        test_image = Image.fromarray(np.ones((32, 32, 3), dtype=np.uint8) * 255)
-        if pipeline is not None:
-            pipeline.preprocess_image(test_image)
-    except Exception as e:
-        print(f"Warning: Failed to preload rembg: {str(e)}")
-
     # Launch the Gradio app
-
-
-
-
-
-
-
-        quiet=True
-    )
-    except Exception as e:
-        print(f"Error launching Gradio app: {str(e)}")
-        exit(1)
+    demo.queue(max_size=2).launch(  # larger queue
+        share=True,
+        max_threads=4,  # more threads
+        show_error=True,
+        server_port=7860,
+        server_name="0.0.0.0"
+    )
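The queue settings bound load rather than remove it: max_size=2 allows at most two pending requests at a time (further requests are turned away until a slot frees up), and max_threads=4 caps the worker pool. A self-contained toy app mirroring the same launch configuration (the real demo is defined earlier in app.py):

import gradio as gr

with gr.Blocks() as demo:
    inp = gr.Textbox(label="prompt")
    out = gr.Textbox(label="echo")
    inp.submit(lambda s: s, inp, out)

demo.queue(max_size=2).launch(
    share=True,              # also expose a public *.gradio.live URL
    max_threads=4,           # cap on worker threads
    show_error=True,         # surface exceptions in the UI
    server_port=7860,
    server_name="0.0.0.0",   # listen on all interfaces (useful in containers)
)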