ginipick commited on
Commit
abde71d
ยท
verified ยท
1 Parent(s): 9f80462

Update models/util.py

Browse files
Files changed (1) hide show
  1. models/util.py +10 -4
models/util.py CHANGED
@@ -420,10 +420,16 @@ def load_flow_model(
420
  # print_load_warning(missing, unexpected)
421
  return model
422
 
423
-
424
- def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
425
- # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
426
- return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
 
 
 
 
 
 
427
 
428
 
429
  def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
 
420
  # print_load_warning(missing, unexpected)
421
  return model
422
 
423
+ # ์•ฝ 426๋ฒˆ ์ค„์— ์œ„์น˜ํ•œ load_t5 ํ•จ์ˆ˜๋ฅผ ์ฐพ์•„ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ˆ˜์ •
424
+ def load_t5(device, max_length=256):
425
+ try:
426
+ # ์›๋ž˜ ์ฝ”๋“œ: ๋Œ€ํ˜• T5-XXL ๋ชจ๋ธ ๋กœ๋“œ ์‹œ๋„
427
+ return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
428
+ except Exception as e:
429
+ print(f"T5-XXL ๋ชจ๋ธ ๋กœ๋”ฉ ์‹คํŒจ: {str(e)}")
430
+ print("๋” ์ž‘์€ T5 ๋ชจ๋ธ๋กœ ๋Œ€์ฒดํ•ฉ๋‹ˆ๋‹ค...")
431
+ # ๋” ์ž‘์€ ๋ชจ๋ธ๋กœ ๋Œ€์ฒด
432
+ return HFEmbedder("google/t5-v1_1-large", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
433
 
434
 
435
  def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: