Spaces: Running on Zero
MohamedRashad committed
Commit · 366fd1c
1 Parent(s): 27e1ebb
Enable bf16 in load_infinity function and enhance transform function with type hints and error handling; refactor joint_vi_vae_encode_decode for improved performance and error management
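For reference, a minimal usage sketch of the reworked transform helper follows (not part of the commit; the import path and the dummy image are illustrative assumptions):

    # Hypothetical usage sketch: exercising the updated transform() from this Space's app.py.
    from PIL import Image

    from app import transform  # assumed import path for this Space

    # A square dummy image keeps the aspect-ratio-preserving resize exact.
    img = Image.new("RGB", (512, 512), color=(128, 64, 32))
    tensor = transform(img, tgt_h=256, tgt_w=256)

    # The refactored transform returns a channels-first float tensor normalized to [-1, 1].
    print(tensor.shape)                 # torch.Size([3, 256, 256])
    print(tensor.min(), tensor.max())   # both values fall within [-1.0, 1.0]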
app.py CHANGED
@@ -102,7 +102,7 @@ def load_infinity(
     text_channels=2048,
     apply_spatial_patchify=0,
     use_flex_attn=False,
-    bf16=
+    bf16=True,
 ):
     print('[Loading Infinity]')
 
@@ -156,45 +156,137 @@ def load_infinity(
 
     # Initialize random number generator on the correct device
     infinity_test.rng = torch.Generator(device=device)
-
+
     return infinity_test
 
-def transform(pil_img, tgt_h, tgt_w):
+def transform(pil_img: PImage.Image, tgt_h: int, tgt_w: int) -> torch.Tensor:
+    """
+    Transform a PIL image to a tensor with target dimensions while preserving aspect ratio.
+
+    Args:
+        pil_img: PIL Image to transform
+        tgt_h: Target height
+        tgt_w: Target width
+
+    Returns:
+        torch.Tensor: Normalized tensor image in range [-1, 1]
+    """
+    if not isinstance(pil_img, PImage.Image):
+        raise TypeError("Input must be a PIL Image")
+
+    if tgt_h <= 0 or tgt_w <= 0:
+        raise ValueError("Target dimensions must be positive")
+
+    # Calculate resize dimensions preserving aspect ratio
     width, height = pil_img.size
-    #
-    arr = np.array(pil_img)
+    scale = min(tgt_w / width, tgt_h / height)
+    new_width = int(width * scale)
+    new_height = int(height * scale)
+
+    # Resize using LANCZOS for best quality
+    pil_img = pil_img.resize((new_width, new_height), resample=PImage.LANCZOS)
+
+    # Create center crop
+    arr = np.array(pil_img, dtype=np.uint8)
+
+    # Calculate crop coordinates
+    y1 = max(0, (new_height - tgt_h) // 2)
+    x1 = max(0, (new_width - tgt_w) // 2)
+    y2 = y1 + tgt_h
+    x2 = x1 + tgt_w
+
+    # Crop and convert to tensor
+    arr = arr[y1:y2, x1:x2]
+
+    # Convert to normalized tensor in one step
+    return torch.from_numpy(arr.transpose(2, 0, 1)).float().div_(127.5).sub_(1)
+
+def joint_vi_vae_encode_decode(
+    vae: 'VAEModel',  # Type hint would be more specific with actual VAE class
+    image_path: str | Path,
+    scale_schedule: List[tuple],
+    device: torch.device | str,
+    tgt_h: int,
+    tgt_w: int
+) -> tuple[np.ndarray, np.ndarray, torch.Tensor]:
+    """
+    Encode and decode an image using a VAE model with joint visual-infinity processing.
+
+    Args:
+        vae: The VAE model instance
+        image_path: Path to input image
+        scale_schedule: List of scale tuples for processing
+        device: Target device for computation
+        tgt_h: Target height for the image
+        tgt_w: Target width for the image
+
+    Returns:
+        tuple containing:
+            - Original image as numpy array (uint8)
+            - Reconstructed image as numpy array (uint8)
+            - Bit indices tensor
+
+    Raises:
+        FileNotFoundError: If image file doesn't exist
+        RuntimeError: If VAE processing fails
+    """
+    try:
+        # Validate input path
+        if not Path(image_path).exists():
+            raise FileNotFoundError(f"Image not found at {image_path}")
+
+        # Load and preprocess image
+        pil_image = Image.open(image_path).convert('RGB')
+        inp = transform(pil_image, tgt_h, tgt_w)
+        inp = inp.unsqueeze(0).to(device)
+
+        # Normalize scale schedule
+        scale_schedule = [(s[0], s[1], s[2]) for s in scale_schedule]
+
+        # Decide whether to use CPU or GPU
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+        # Time the encoding/decoding operations
+        with torch.amp.autocast(device, dtype=torch.bfloat16):
+            encode_start = time.perf_counter()
+            h, z, _, all_bit_indices, _, _ = vae.encode(
+                inp,
+                scale_schedule=scale_schedule
+            )
+            encode_time = time.perf_counter() - encode_start
+
+            decode_start = time.perf_counter()
+            recons_img = vae.decode(z)[0]
+            decode_time = time.perf_counter() - decode_start
+
+        # Process reconstruction
+        if recons_img.dim() == 4:
+            recons_img = recons_img.squeeze(1)
+
+        # Log performance metrics
+        print(f'VAE encode: {encode_time:.2f}s, decode: {decode_time:.2f}s')
+        print(f'Reconstruction shape: {recons_img.shape}, z shape: {z.shape}')
+
+        # Convert to numpy arrays efficiently
+        recons_img = (recons_img.add(1).div(2)
+                      .permute(1, 2, 0)
+                      .mul(255)
+                      .cpu()
+                      .numpy()
+                      .astype(np.uint8))
+
+        gt_img = (inp[0].add(1).div(2)
+                  .permute(1, 2, 0)
+                  .mul(255)
+                  .cpu()
+                  .numpy()
+                  .astype(np.uint8))
+
+        return gt_img, recons_img, all_bit_indices
+
+    except Exception as e:
+        print(f"Error in VAE processing: {str(e)}")
+        raise RuntimeError("VAE processing failed") from e
 
 def load_visual_tokenizer(args):
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -219,29 +311,26 @@ def load_visual_tokenizer(args):
     return vae
 
 def load_transformer(vae, args):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     model_path = args.model_path
+
+    if args.checkpoint_type == 'torch':
         if osp.exists(args.cache_dir):
             local_model_path = osp.join(args.cache_dir, 'tmp', model_path.replace('/', '_'))
         else:
             local_model_path = model_path
+
         if args.enable_model_cache:
             slim_model_path = model_path.replace('ar-', 'slim-')
             local_slim_model_path = local_model_path.replace('ar-', 'slim-')
             os.makedirs(osp.dirname(local_slim_model_path), exist_ok=True)
-            print(f'model_path: {model_path}, slim_model_path: {slim_model_path}')
-            print(f'local_model_path: {local_model_path}, local_slim_model_path: {local_slim_model_path}')
            if not osp.exists(local_slim_model_path):
                 if osp.exists(slim_model_path):
-                    print(f'copy {slim_model_path} to {local_slim_model_path}')
                     shutil.copyfile(slim_model_path, local_slim_model_path)
                 else:
                     if not osp.exists(local_model_path):
-                        print(f'copy {model_path} to {local_model_path}')
                         shutil.copyfile(model_path, local_model_path)
                     save_slim_model(local_model_path, save_file=local_slim_model_path, device=device)
-                    print(f'copy {local_slim_model_path} to {slim_model_path}')
                     if not osp.exists(slim_model_path):
                         shutil.copyfile(local_slim_model_path, slim_model_path)
                     os.remove(local_model_path)
@@ -249,33 +338,35 @@ def load_transformer(vae, args):
             slim_model_path = local_slim_model_path
         else:
             slim_model_path = model_path
-        print(f'
+        print(f'Loading checkpoint from {slim_model_path}')
+    else:
+        raise ValueError(f"Unsupported checkpoint_type: {args.checkpoint_type}")
+
+    model_configs = {
+        'infinity_2b': dict(depth=32, embed_dim=2048, num_heads=16, drop_path_rate=0.1, mlp_ratio=4, block_chunks=8),
+        'infinity_layer12': dict(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+        'infinity_layer16': dict(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+        'infinity_layer24': dict(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+        'infinity_layer32': dict(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+        'infinity_layer40': dict(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+        'infinity_layer48': dict(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
+    }
+
+    kwargs_model = model_configs.get(args.model_type)
+    if kwargs_model is None:
+        raise ValueError(f"Unsupported model_type: {args.model_type}")
+
     infinity = load_infinity(
-        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
+        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
         rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
         use_scale_schedule_embedding=args.use_scale_schedule_embedding,
         pn=args.pn,
-        use_bit_label=args.use_bit_label,
-        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
-        model_path=slim_model_path,
-        scale_schedule=None,
-        vae=vae,
-        device=
+        use_bit_label=args.use_bit_label,
+        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
+        model_path=slim_model_path,
+        scale_schedule=None,
+        vae=vae,
+        device=device,
         model_kwargs=kwargs_model,
         text_channels=args.text_channels,
         apply_spatial_patchify=args.apply_spatial_patchify,