MohamedRashad committed
Commit eecb045 · Parent(s): 366fd1c

Refactor load_tokenizer function to include error handling and device optimizations; streamline model loading process and improve memory management

Files changed (1): app.py (+42 −36)
app.py CHANGED
@@ -57,7 +57,7 @@ def encode_prompt(text_tokenizer, text_encoder, prompt):
     text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()
     lens: List[int] = mask.sum(dim=-1).tolist()
     cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
-    Ltext = max(lens)
+    Ltext = max(lens)
     kv_compact = []
     for len_i, feat_i in zip(lens, text_features.unbind(0)):
         kv_compact.append(feat_i[:len_i])
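
For readers unfamiliar with this packing scheme: `encode_prompt` flattens a padded batch of T5 features into one packed tensor plus cumulative sequence boundaries (`cu_seqlens_k`), the layout that varlen attention kernels expect. A minimal standalone sketch with toy shapes (not from app.py):

```python
# Standalone sketch: how encode_prompt packs variable-length prompts.
# Tensor shapes here are illustrative assumptions, not values from app.py.
import torch
import torch.nn.functional as F

batch, max_len, dim = 2, 6, 4
mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                     [1, 1, 1, 1, 1, 0]])          # attention mask per prompt
text_features = torch.randn(batch, max_len, dim)   # stand-in for T5 hidden states

lens = mask.sum(dim=-1).tolist()                   # true length of each prompt
# cumulative boundaries [0, 3, 8] used by varlen attention kernels
cu_seqlens_k = F.pad(mask.sum(dim=-1).to(torch.int32).cumsum_(0), (1, 0))
# drop padding and concatenate all prompts into one packed tensor
kv_compact = torch.cat([feat[:n] for n, feat in zip(lens, text_features.unbind(0))])

print(lens)              # [3, 5]
print(cu_seqlens_k)      # tensor([0, 3, 8], dtype=torch.int32)
print(kv_compact.shape)  # torch.Size([8, 4])
```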
@@ -77,15 +77,40 @@ def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_
     print('[Save slim model] done')
     return save_file
 
-def load_tokenizer(t5_path=''):
-    print('[Loading tokenizer and text encoder]')
-    tokenizer = AutoTokenizer.from_pretrained(t5_path, legacy=True)
-    tokenizer.model_max_length = 512
-    encoder = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
-    encoder.eval()
-    encoder.to("cuda" if torch.cuda.is_available() else "cpu")
-    encoder.requires_grad_(False)
-    return tokenizer, encoder
+def load_tokenizer(t5_path='google/flan-t5-xl'):
+    """
+    Load and configure the T5 tokenizer and encoder with optimizations.
+    """
+    try:
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        bf16_supported = device.type == 'cuda' and torch.cuda.is_bf16_supported()
+        dtype = torch.bfloat16 if bf16_supported else torch.float32
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            t5_path,
+            legacy=True,
+            model_max_length=512,
+            use_fast=True,
+        )
+
+        if device.type == 'cuda':
+            torch.cuda.empty_cache()
+
+        encoder = T5EncoderModel.from_pretrained(
+            t5_path,
+            torch_dtype=dtype,
+        )
+
+        encoder.eval().requires_grad_(False).to(device)
+
+        if device.type == 'cuda' and not bf16_supported:
+            encoder.half()
+
+        return tokenizer, encoder
+
+    except Exception as e:
+        print(f"Error loading tokenizer/encoder: {str(e)}")
+        raise RuntimeError("Failed to initialize text models") from e
 
 def load_infinity(
     rope2d_each_sa_layer,
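
The refactored loader now picks bfloat16 on GPUs that support it, falls back to float32 weights cast to half precision otherwise, and freezes the encoder. A hedged usage sketch, assuming the function above is in scope inside app.py; flan-t5-xl is a ~3B-parameter model whose d_model is 2048, so each token gets a 2048-dimensional feature:

```python
# Usage sketch (assumes load_tokenizer as defined in the diff above).
import torch

text_tokenizer, text_encoder = load_tokenizer(t5_path="google/flan-t5-xl")
batch = text_tokenizer(
    ["a photo of a cat"],
    return_tensors="pt",
    padding=True,
    truncation=True,
).to(text_encoder.device)
with torch.no_grad():
    hidden = text_encoder(**batch)["last_hidden_state"]
# (batch, seq_len, 2048); bfloat16 on bf16-capable GPUs, else fp16 on GPU / fp32 on CPU
print(hidden.shape, hidden.dtype)
```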
@@ -154,8 +179,8 @@ def load_infinity(
     state_dict = torch.load(model_path, map_location=device)
     print(infinity_test.load_state_dict(state_dict))
 
-    # Initialize random number generator on the correct device
-    infinity_test.rng = torch.Generator(device=device)
+    # # Initialize random number generator on the correct device
+    # infinity_test.rng = torch.Generator(device=device)
 
     return infinity_test
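
Context for the commented-out lines: a `torch.Generator` is device-specific, so seeded sampling on GPU needs a generator created on that device; with the assignment disabled, sampling presumably falls back to PyTorch's global generator. A small illustration (not from the commit):

```python
# Illustration: device-specific seeded sampling with torch.Generator.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
rng = torch.Generator(device=device)   # a generator lives on one device
rng.manual_seed(42)
# reproducible draw; the generator's device must match the output device
noise = torch.randn(4, 4, generator=rng, device=device)
```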
 
@@ -315,29 +340,7 @@ def load_transformer(vae, args):
     model_path = args.model_path
 
     if args.checkpoint_type == 'torch':
-        if osp.exists(args.cache_dir):
-            local_model_path = osp.join(args.cache_dir, 'tmp', model_path.replace('/', '_'))
-        else:
-            local_model_path = model_path
-
-        if args.enable_model_cache:
-            slim_model_path = model_path.replace('ar-', 'slim-')
-            local_slim_model_path = local_model_path.replace('ar-', 'slim-')
-            os.makedirs(osp.dirname(local_slim_model_path), exist_ok=True)
-            if not osp.exists(local_slim_model_path):
-                if osp.exists(slim_model_path):
-                    shutil.copyfile(slim_model_path, local_slim_model_path)
-                else:
-                    if not osp.exists(local_model_path):
-                        shutil.copyfile(model_path, local_model_path)
-                    save_slim_model(local_model_path, save_file=local_slim_model_path, device=device)
-                    if not osp.exists(slim_model_path):
-                        shutil.copyfile(local_slim_model_path, slim_model_path)
-                os.remove(local_model_path)
-                os.remove(model_path)
-            slim_model_path = local_slim_model_path
-        else:
-            slim_model_path = model_path
+        slim_model_path = model_path
         print(f'Loading checkpoint from {slim_model_path}')
     else:
         raise ValueError(f"Unsupported checkpoint_type: {args.checkpoint_type}")
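
The removed branch cached a slimmed copy of the full 'ar-' checkpoint (converted once via `save_slim_model`, staged through `args.cache_dir`, deleting the originals afterwards); after this commit `args.model_path` is loaded directly. A hypothetical condensation of the removed flow, for reference only; behavior is inferred from the deleted lines:

```python
# Hypothetical condensation of the removed caching branch (cache-dir staging
# and original-file cleanup omitted). Not part of the commit.
import os.path as osp

def resolve_slim_checkpoint(model_path, enable_model_cache):
    if not enable_model_cache:
        return model_path                           # new behavior: load as-is
    slim_path = model_path.replace('ar-', 'slim-')  # naming scheme from the diff
    if not osp.exists(slim_path):
        # one-time conversion via save_slim_model (defined earlier in app.py)
        slim_path = save_slim_model(model_path, save_file=slim_path, device='cpu')
    return slim_path
```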
@@ -465,10 +468,13 @@ args = argparse.Namespace(
 )
 
 # Load models
+print(f"VRAM before forward pass: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
 text_tokenizer, text_encoder = load_tokenizer(t5_path="google/flan-t5-xl")
+print(f"VRAM before forward pass: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
 vae = load_visual_tokenizer(args)
+print(f"VRAM before forward pass: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
 infinity = load_transformer(vae, args)
-
+print(f"VRAM before forward pass: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
 
 # Define the image generation function
 @spaces.GPU
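
One caveat about the logging added above: the label reads "before forward pass" at every stage, and `torch.cuda.memory_allocated()` is only meaningful when CUDA is available. A small stage-aware variant (not in the commit):

```python
# Stage-aware variant of the VRAM logging above, with a CPU-safe guard.
import torch

def log_vram(stage: str) -> None:
    if torch.cuda.is_available():
        mb = torch.cuda.memory_allocated() / 1024 ** 2
        print(f"VRAM after {stage}: {mb:.2f} MB")
    else:
        print(f"VRAM after {stage}: n/a (CPU)")

# log_vram("load_tokenizer"); log_vram("load_visual_tokenizer"); ...
```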
 