NaFlex in timm

Published April 9, 2025

Amidst all of the distractions I've been plodding away on some new bits of code -- finally tackling variable image size / aspect ratio support in timm via the approach taken by SigLIP2 NaFlex. Over a year ago I worked on the NaViT approach but got blocked on the data pipeline. That's also the biggest challenge here, but it's not as tricky since no sequence packing is required. So why is the data loading a pain? Because of the canonical PyTorch data pipeline.

The way things are typically organized -- a 'dataset', 'transforms', a 'sampler', a 'dataloader', and a 'collate' fn -- the start of the data-processing pipeline in the dataset / transforms is disconnected from the sampler and the collation fn, and there's a dataloader worker-process boundary between some of these. There are also limited built-in mechanisms to synchronize state between those entities... so if you want to vary sequence length, and thus batch size, per batch, it gets ugly. I made several false starts and then decided: screw it, everything goes in the dataset.
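
To make the disconnect concrete, here's a minimal sketch of the canonical setup (toy code, not anything from timm; the dataset, transform, and numbers are made up):

```python
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler, default_collate


def to_float(img: torch.Tensor) -> torch.Tensor:
    # stand-in for a real resize / normalize transform; fixed at construction time
    return img.float() / 255.0


class ToyImageDataset(Dataset):
    """Hypothetical map-style dataset of random 'images'."""

    def __init__(self, num_samples: int = 1000):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
        return to_float(img), index % 10   # has no idea which batch it will land in


ds = ToyImageDataset()
loader = DataLoader(
    ds,                            # dataset + transforms run in the worker processes
    batch_size=256,                # fixed up front; can't co-vary with seq len per batch
    sampler=RandomSampler(ds),     # lives in the main process, knows nothing of transforms
    collate_fn=default_collate,    # runs in the workers, knows nothing of the sampler
    num_workers=4,
)
```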

Huh? Yes, batching and sampling in the dataloader are disabled, and there's a dataset wrapper (itself an IterableDataset) that determines seq len & batch size, samples indices (distributed aware), runs the transform, then batches and does the bulk of patchification and collation itself.
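
Here's a rough sketch of the shape of that idea. It's a simplified illustration, not the actual timm wrapper: the class name and defaults are made up, distributed-aware index sharding is omitted, and a real NaFlex pipeline also needs to emit patch coordinates and a valid-patch mask alongside the padded patches.

```python
import random

import torch
from torch.utils.data import DataLoader, IterableDataset


class VariableSeqLenBatcher(IterableDataset):
    """Hypothetical wrapper: yields fully collated (patches, labels) batches."""

    def __init__(self, dataset, seq_lens=(256, 576, 784, 1024),
                 max_tokens_per_batch=65536, patch_size=16):
        self.dataset = dataset                    # wrapped map-style dataset of (img, label)
        self.seq_lens = seq_lens
        self.max_tokens_per_batch = max_tokens_per_batch
        self.patch_size = patch_size

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        random.shuffle(indices)                   # distributed-aware sharding omitted here
        cursor = 0
        while cursor < len(indices):
            seq_len = random.choice(self.seq_lens)             # per-batch sequence length
            batch_size = self.max_tokens_per_batch // seq_len  # shorter seqs -> bigger batches
            batch_indices = indices[cursor:cursor + batch_size]
            cursor += batch_size
            patches, labels = [], []
            for idx in batch_indices:
                img, label = self.dataset[idx]    # transform sized for <= seq_len patches
                patches.append(self._patchify(img, seq_len))
                labels.append(label)
            yield torch.stack(patches), torch.tensor(labels)

    def _patchify(self, img, seq_len):
        # (C, H, W) -> (seq_len, C * p * p), zero-padding unused patch slots;
        # a real version would also return patch coords and a valid-patch mask
        c, _, _ = img.shape
        p = self.patch_size
        patches = img.unfold(1, p, p).unfold(2, p, p).reshape(c, -1, p, p)
        patches = patches.permute(1, 0, 2, 3).reshape(-1, c * p * p)
        out = patches.new_zeros(seq_len, c * p * p)
        n = min(seq_len, patches.shape[0])
        out[:n] = patches[:n]
        return out


# The DataLoader no longer batches; batch_size=None just passes the wrapper's batches through:
# loader = DataLoader(VariableSeqLenBatcher(my_dataset), batch_size=None, num_workers=4)
```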

I currently have what I consider an 'alpha' PR in timm. Still lots to test & polish, especially distributed support. What's implemented:

  • A NaFlex-style ViT that should be able to flip between NaFlex sequence mode and fixed mode.

  • Optimized NaFlex position embedding interpolation in the model, at least improved over what's in the existing PyTorch impl (see the interpolation sketch after this list).

  • A set of 'sequence'-oriented transforms that constrain image sizes to target sequence lengths (see the sizing sketch after this list).

  • A dataset wrapper for map-style datasets, with an iterable wrapper on the way.

  • Dataset wrappers control seq len vs batch size tradeoffs and try to maximize GPU utilization, so as seq length decreases from the max, batch size increases. This appears to be working well in PyTorch w/ GPU, incl. w/ torch.compile, since the set of combos is limited (see the tradeoff sketch after this list). I think XLA would cope?

  • Modifications to training scripts to handle variable batch sizes per step, including options to scale gradients and handle different batch sizes per distributed rank (see the loss-scaling sketch after this list).
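
On the position embedding bullet above: the general NaFlex idea is to keep a learned embedding on a base grid and resample it to each target patch grid. A minimal sketch of that resampling, assuming a square base grid and bilinear interpolation (illustrative only, not timm's exact implementation):

```python
import torch
import torch.nn.functional as F


def resample_pos_embed(pos_embed: torch.Tensor, new_hw: tuple) -> torch.Tensor:
    """(1, base_h * base_w, dim) on a square base grid -> (1, new_h * new_w, dim)."""
    _, n, dim = pos_embed.shape
    base = int(n ** 0.5)
    grid = pos_embed.reshape(1, base, base, dim).permute(0, 3, 1, 2)   # (1, dim, H, W)
    grid = F.interpolate(grid, size=new_hw, mode='bilinear',
                         align_corners=False, antialias=True)
    return grid.permute(0, 2, 3, 1).reshape(1, new_hw[0] * new_hw[1], dim)


# e.g. a 24x24 base grid resampled for a 12x32 patch grid (a wide, short image)
pos = torch.randn(1, 24 * 24, 768)
print(resample_pos_embed(pos, (12, 32)).shape)   # torch.Size([1, 384, 768])
```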
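
For the sequence-oriented transforms, the core sizing step is choosing output dimensions that preserve aspect ratio while keeping the patch count at or under the target sequence length. A sketch of that calculation (my reading of the idea; the function name and details are illustrative):

```python
import math


def size_for_seq_len(img_h: int, img_w: int, seq_len: int, patch_size: int = 16):
    # scale so that (h / p) * (w / p) <= seq_len while preserving aspect ratio
    scale = math.sqrt(seq_len * patch_size ** 2 / (img_h * img_w))
    h = max(patch_size, int(img_h * scale) // patch_size * patch_size)
    w = max(patch_size, int(img_w * scale) // patch_size * patch_size)
    return h, w   # extreme aspect ratios would need extra clamping


print(size_for_seq_len(480, 640, seq_len=576))   # (320, 432) -> 20 * 27 = 540 patches
```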
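
For the seq len vs batch size tradeoff, one way to keep the shape set small enough for torch.compile is to precompute a fixed table of (seq_len, batch_size) pairs that hold total tokens per batch roughly constant (an assumed scheme with example numbers, not timm's actual table):

```python
MAX_SEQ_LEN = 1024
MAX_BATCH_SIZE = 64                      # batch size used at the maximum sequence length
SEQ_LENS = (256, 400, 576, 784, 1024)    # example sequence-length buckets

# shorter sequences -> proportionally larger batches, total tokens per batch ~constant
COMBOS = [(s, MAX_BATCH_SIZE * MAX_SEQ_LEN // s) for s in SEQ_LENS]
print(COMBOS)   # [(256, 256), (400, 163), (576, 113), (784, 83), (1024, 64)]
```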
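
And for variable batch sizes across steps and ranks, one option for the gradient scaling is to rescale each rank's (mean) loss so that DDP's average of per-rank gradients matches a true per-sample mean over the global batch. This is a sketch of one such approach, not the exact training-script change:

```python
import torch
import torch.distributed as dist


def scale_loss_for_global_mean(loss: torch.Tensor, local_batch_size: int) -> torch.Tensor:
    """Rescale a per-rank mean loss so DDP's gradient average equals a global per-sample mean."""
    if not (dist.is_available() and dist.is_initialized()):
        return loss
    world_size = dist.get_world_size()
    global_bs = torch.tensor([float(local_batch_size)], device=loss.device)
    dist.all_reduce(global_bs, op=dist.ReduceOp.SUM)    # total samples across ranks this step
    # DDP averages gradients over ranks; after scaling, each sample contributes 1 / global_bs
    return loss * (local_batch_size * world_size / global_bs.item())
```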

Reference: SigLIP-2 NaFlex details are discussed in the paper: https://arxiv.org/html/2502.14786
