OMG_Seg / ext /open_clip /zero_shot_classifier.py
HarborYuan's picture
add omg code
b34d1d6
raw
history blame
No virus
4.35 kB
from functools import partial
from itertools import islice
from typing import Callable, List, Optional, Sequence, Union
import torch
import torch.nn.functional as F
def batched(iterable, n):
    """Batch data into lists of length *n*. The last batch may be shorter.

    NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl

    Args:
        iterable: Any iterable. It is consumed lazily, one batch at a time.
        n: Maximum number of items per batch. ``None`` places no limit on the
            batch size, so all items are yielded as a single batch.

    Yields:
        Lists holding up to *n* consecutive items, in their original order.

    Raises:
        ValueError: If *n* is smaller than one. Because this is a generator,
            the error surfaces when iteration starts, not at call time.
    """
    # islice(it, 0) never returns anything, so the loop below would stop at once
    # and silently drop every item; reject such sizes explicitly instead.
    if n is not None and n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    while True:
        batch = list(islice(it, n))
        if not batch:
            # Source iterator is exhausted.
            break
        yield batch
def build_zero_shot_classifier(
        model,
        tokenizer,
        classnames: Sequence[str],
        templates: Sequence[Union[Callable, str]],
        num_classes_per_batch: Optional[int] = 10,
        device: Union[str, torch.device] = 'cpu',
        use_tqdm: bool = False,
):
    """ Build zero-shot classifier weights by iterating over class names in batches

    Every class name is expanded with every template, the resulting texts are
    encoded with ``model.encode_text``, and the normalized per-template
    embeddings are averaged and re-normalized into one vector per class.

    Args:
        model: CLIP model instance
        tokenizer: CLIP tokenizer instance
        classnames: A sequence of class (label) names
        templates: A sequence of callables or format() friendly strings to produce templates per class name
        num_classes_per_batch: The number of classes to batch together in each forward, all if None
            (0 is treated the same as None)
        device: Device to use.
        use_tqdm: Enable TQDM progress bar. Only has an effect when batching is active.

    Returns:
        A tensor with one unit-norm column per class, i.e. laid out as
        (embedding dim, number of classes), in the order of ``classnames``.
    """
    assert isinstance(templates, Sequence) and len(templates) > 0
    assert isinstance(classnames, Sequence) and len(classnames) > 0
    # The first template decides how all templates are applied: str.format vs. call.
    use_format = isinstance(templates[0], str)
    num_templates = len(templates)
    num_classes = len(classnames)
    # None and 0 both mean "process every class in a single forward pass".
    use_batching = bool(num_classes_per_batch)

    if use_tqdm and use_batching:
        # The progress wrapper is only built when there is a batch loop to wrap;
        # this also keeps the ceil-division below safe from a zero batch size.
        import tqdm
        num_iter = (num_classes - 1) // num_classes_per_batch + 1
        # unit_scale makes the bar count classes instead of batches.
        iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch)
    else:
        iter_wrap = iter

    def _process_batch(batch_classnames):
        num_batch_classes = len(batch_classnames)
        # Class-major order: all templates of the first class, then the second, ...
        # The reshape below relies on this ordering.
        texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates]
        texts = tokenizer(texts).to(device)
        class_embeddings = F.normalize(model.encode_text(texts), dim=-1)
        # Group the embeddings per class and average over the template axis.
        class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1)
        # The mean of unit vectors is generally not unit length; normalize again.
        class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True)
        # Transpose so that classes run along dim 1.
        class_embeddings = class_embeddings.T
        return class_embeddings

    with torch.no_grad():
        if use_batching:
            batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))]
            zeroshot_weights = torch.cat(batched_embeds, dim=1)
        else:
            zeroshot_weights = _process_batch(classnames)
    return zeroshot_weights
def build_zero_shot_classifier_legacy(
        model,
        tokenizer,
        classnames: Sequence[str],
        templates: Sequence[Union[Callable, str]],
        device: Union[str, torch.device] = 'cpu',
        use_tqdm: bool = False,
):
    """ Build zero-shot classifier weights, encoding one class name per forward pass

    For each class, all templates are rendered, tokenized and encoded; the
    normalized embeddings are averaged over the templates and normalized again.

    Args:
        model: CLIP model instance
        tokenizer: CLIP tokenizer instance
        classnames: A sequence of class (label) names
        templates: A sequence of callables or format() friendly strings to produce templates per class name
        device: Device to use.
        use_tqdm: Enable TQDM progress bar.

    Returns:
        The per-class embeddings stacked along dim 1 (one column per class),
        moved to ``device``.
    """
    assert isinstance(templates, Sequence) and len(templates) > 0
    assert isinstance(classnames, Sequence) and len(classnames) > 0

    if use_tqdm:
        import tqdm
        progress = tqdm.tqdm
    else:
        progress = iter

    # The first template decides whether templates are format strings or callables.
    templates_are_strings = isinstance(templates[0], str)

    def _render(name):
        # Produce one prompt per template for the given class name.
        if templates_are_strings:
            return [t.format(name) for t in templates]
        return [t(name) for t in templates]

    per_class = []
    with torch.no_grad():
        for name in progress(classnames):
            tokens = tokenizer(_render(name)).to(device)
            embeds = model.encode_text(tokens)
            # Average the unit-length template embeddings, then restore unit norm.
            pooled = F.normalize(embeds, dim=-1).mean(dim=0)
            pooled = pooled / pooled.norm()
            per_class.append(pooled)
        weights = torch.stack(per_class, dim=1).to(device)

    return weights