Spaces: Running on Zero

Commit eecb045 · 1 parent: 366fd1c
Refactor load_tokenizer function to include error handling and device optimizations; streamline model loading process and improve memory management

app.py CHANGED
@@ -57,7 +57,7 @@ def encode_prompt(text_tokenizer, text_encoder, prompt):
     text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()
     lens: List[int] = mask.sum(dim=-1).tolist()
     cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
-    Ltext = max(lens)
+    Ltext = max(lens)
     kv_compact = []
     for len_i, feat_i in zip(lens, text_features.unbind(0)):
         kv_compact.append(feat_i[:len_i])
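For reference, the cu_seqlens_k line in this hunk turns the per-prompt token counts into cumulative offsets for the packed key/value tensor. A minimal sketch, with mask values made up for illustration:

import torch
import torch.nn.functional as F

# Two prompts of length 3 and 5 (assumed values, not from app.py).
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])
lens = mask.sum(dim=-1).tolist()  # [3, 5]
cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
print(cu_seqlens_k)  # tensor([0, 3, 8], dtype=torch.int32)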
@@ -77,15 +77,40 @@ def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_
     print('[Save slim model] done')
     return save_file
 
-def load_tokenizer(t5_path=''):
-
-    tokenizer
-
-
-
-
-
-
+def load_tokenizer(t5_path='google/flan-t5-xl'):
+    """
+    Load and configure the T5 tokenizer and encoder with optimizations.
+    """
+    try:
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        bf16_supported = device.type == 'cuda' and torch.cuda.is_bf16_supported()
+        dtype = torch.bfloat16 if bf16_supported else torch.float32
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            t5_path,
+            legacy=True,
+            model_max_length=512,
+            use_fast=True,
+        )
+
+        if device.type == 'cuda':
+            torch.cuda.empty_cache()
+
+        encoder = T5EncoderModel.from_pretrained(
+            t5_path,
+            torch_dtype=dtype,
+        )
+
+        encoder.eval().requires_grad_(False).to(device)
+
+        if device.type == 'cuda' and not bf16_supported:
+            encoder.half()
+
+        return tokenizer, encoder
+
+    except Exception as e:
+        print(f"Error loading tokenizer/encoder: {str(e)}")
+        raise RuntimeError("Failed to initialize text models") from e
 
 def load_infinity(
     rope2d_each_sa_layer,
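A short usage sketch of the refactored loader as it appears above; the prompt string is made up, and the call mirrors what encode_prompt() in app.py does with the returned pair (assuming the AutoTokenizer / T5EncoderModel imports already used by the file):

text_tokenizer, text_encoder = load_tokenizer(t5_path='google/flan-t5-xl')

tokens = text_tokenizer(['a photo of a red fox'], return_tensors='pt',
                        padding=True, truncation=True)
device = next(text_encoder.parameters()).device
with torch.no_grad():
    features = text_encoder(
        input_ids=tokens.input_ids.to(device),
        attention_mask=tokens.attention_mask.to(device),
    )['last_hidden_state'].float()
print(features.shape)  # (batch, seq_len, 2048) for flan-t5-xl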
@@ -154,8 +179,8 @@ def load_infinity(
     state_dict = torch.load(model_path, map_location=device)
     print(infinity_test.load_state_dict(state_dict))
 
-    # Initialize random number generator on the correct device
-    infinity_test.rng = torch.Generator(device=device)
+    # # Initialize random number generator on the correct device
+    # infinity_test.rng = torch.Generator(device=device)
 
     return infinity_test
 
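The commit comments the generator out rather than moving it; on a ZeroGPU Space the CUDA device is only guaranteed inside the @spaces.GPU-decorated call, so attaching a CUDA generator at load time is fragile. A possible follow-up, purely a sketch and not part of this commit, is to build the generator where sampling actually runs:

# Hypothetical helper; the name and call site are assumptions, not from the diff.
def make_rng(device, seed=None):
    g = torch.Generator(device=device)
    if seed is not None:
        g.manual_seed(seed)
    return g

# e.g. inside the @spaces.GPU generation function:
# infinity.rng = make_rng('cuda', seed=42)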
@@ -315,29 +340,7 @@ def load_transformer(vae, args):
     model_path = args.model_path
 
     if args.checkpoint_type == 'torch':
-        if osp.exists(args.cache_dir):
-            local_model_path = osp.join(args.cache_dir, 'tmp', model_path.replace('/', '_'))
-        else:
-            local_model_path = model_path
-
-        if args.enable_model_cache:
-            slim_model_path = model_path.replace('ar-', 'slim-')
-            local_slim_model_path = local_model_path.replace('ar-', 'slim-')
-            os.makedirs(osp.dirname(local_slim_model_path), exist_ok=True)
-            if not osp.exists(local_slim_model_path):
-                if osp.exists(slim_model_path):
-                    shutil.copyfile(slim_model_path, local_slim_model_path)
-                else:
-                    if not osp.exists(local_model_path):
-                        shutil.copyfile(model_path, local_model_path)
-                    save_slim_model(local_model_path, save_file=local_slim_model_path, device=device)
-                    if not osp.exists(slim_model_path):
-                        shutil.copyfile(local_slim_model_path, slim_model_path)
-                    os.remove(local_model_path)
-                    os.remove(model_path)
-            slim_model_path = local_slim_model_path
-        else:
-            slim_model_path = model_path
+        slim_model_path = model_path
         print(f'Loading checkpoint from {slim_model_path}')
     else:
         raise ValueError(f"Unsupported checkpoint_type: {args.checkpoint_type}")
@@ -465,10 +468,13 @@ args = argparse.Namespace(
 )
 
 # Load models
+print(f"VRAM before forward pass: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
 text_tokenizer, text_encoder = load_tokenizer(t5_path="google/flan-t5-xl")
+print(f"VRAM before forward pass: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
 vae = load_visual_tokenizer(args)
+print(f"VRAM before forward pass: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
 infinity = load_transformer(vae, args)
-
+print(f"VRAM before forward pass: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
 
 # Define the image generation function
 @spaces.GPU
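The same "VRAM before forward pass" line is printed after each loading step. If the logging is kept, the repeated f-string could be factored into a small helper with a step-specific label; this is hypothetical and not part of the commit:

# Hypothetical helper; log_vram is an assumed name, not in app.py.
def log_vram(tag):
    if torch.cuda.is_available():
        print(f"VRAM {tag}: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
    else:
        print(f"VRAM {tag}: CUDA not available")

# log_vram('after load_tokenizer')
# log_vram('after load_transformer')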