|
from enum import Enum |
|
import subprocess |
|
import sys |
|
import shutil |
|
import torch |
|
import torch.distributed as dist |
|
from torchvision.transforms import functional as F |
|
from torchvision import transforms as T |
|
from transformers import AutoFeatureExtractor |
|
from PIL import Image, ImageDraw, ImageFont, ImageOps |
|
import requests |
|
from io import BytesIO |
|
|
|
import random |
|
|
|
|
|
def dump_git_status(out_file=sys.stdout, exclude_file_patterns=('*.ipynb', '*.th', '*.sh', '*.txt', '*.json')):
    """Log the current git commit hash and working-tree diff to ``out_file``.

    Args:
        out_file: Destination for the output. It is passed to ``subprocess``
            as ``stdout``, so it must be a real file object with a
            ``fileno()`` (e.g. ``sys.stdout`` or a file opened for writing).
        exclude_file_patterns: Glob patterns of files to leave out of the
            diff, applied as git ``:(exclude)`` pathspecs. ``None`` or an
            empty sequence excludes nothing.
    """
    # Previously these patterns were accepted but never applied (the exclude
    # string was always empty). Turn each one into an exclude pathspec.
    exclude_pathspecs = [':(exclude){}'.format(pattern) for pattern in (exclude_file_patterns or ())]
    try:
        # List-form arguments avoid shell quoting issues with the pathspecs.
        subprocess.call(['git', 'rev-parse', 'HEAD'], stdout=out_file)
        subprocess.call('echo', shell=True, stdout=out_file)
        subprocess.call(['git', '--no-pager', 'diff', '--', '.', *exclude_pathspecs], stdout=out_file)
    except FileNotFoundError:
        # This is best-effort logging: a missing git binary should not crash
        # the caller (the old shell-based call only printed a shell error).
        print('[WARNING] git executable not found; skipping git status dump.', file=sys.stderr)
|
|
|
|
|
def get_image_from_url(url: str, size=(224, 224), timeout=60.0):
    """Download an image and return it as a resized RGB PIL image.

    Args:
        url: HTTP(S) URL of the image.
        size: ``(width, height)`` passed to ``Image.resize``; defaults to
            224x224, the previously hard-coded size.
        timeout: Seconds passed to ``requests.get`` so that a stalled server
            cannot hang the caller forever. ``None`` disables the timeout.

    Returns:
        An RGB ``PIL.Image.Image`` of the requested size.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL,
    # which would surface as an unhelpful "cannot identify image file" error.
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
    # Convert before resizing. Pillow resizes palette ("P") and bilevel ("1")
    # images with nearest-neighbour sampling, so converting first gives a
    # properly filtered RGB result.
    img = img.convert('RGB')
    return img.resize(size)
|
|
|
|
|
def truncate_caption(caption: str) -> str:
    """Truncate a caption after its first line or, failing that, its first sentence.

    - If the caption contains a newline, everything after the first newline
      is dropped; the newline itself is kept.
    - Otherwise, if it contains a period, everything after the first period
      is dropped; the period is kept.
    - If it contains neither, the caption is returned unchanged.
    """
    trunc_index = caption.find('\n') + 1
    if trunc_index <= 0:
        trunc_index = caption.find('.') + 1
    # find() returns -1 when the delimiter is absent, which leaves
    # trunc_index == 0. Slicing with [:0] would wrongly discard the whole
    # caption, so only truncate when a delimiter was actually found.
    if trunc_index > 0:
        caption = caption[:trunc_index]
    return caption
|
|
|
|
|
def pad_to_size(x, size=256):
    """Add a centered border to a PIL image so it becomes ``size`` x ``size``.

    Any odd leftover pixel goes on the right/bottom edge.

    NOTE(review): if ``x`` is larger than ``size`` along an axis, the border
    for that axis is negative. Presumably ``ImageOps.expand`` then crops
    instead of padding; confirm this is intended.

    Args:
        x: Input PIL image.
        size: Target side length in pixels.

    Returns:
        The result of ``ImageOps.expand`` with the computed
        ``(left, top, right, bottom)`` border.
    """
    width, height = x.size
    extra_w, extra_h = size - width, size - height
    left, top = extra_w // 2, extra_h // 2
    border = (left, top, extra_w - left, extra_h - top)
    return ImageOps.expand(x, border)
|
|
|
|
|
class RandCropResize(object):
    """Randomly crop, randomly resize, then randomly crop an image again.

    Mirrors the augmentations from https://arxiv.org/abs/2102.12092.

    Args:
        target_size: Side length of the square output crop.
        min_size: If the shorter side is below this after the random resize,
            the image is resized up to it before the final crop. Defaults to
            256, the previously hard-coded value.
    """

    def __init__(self, target_size, min_size=256):
        self.target_size = target_size
        self.min_size = min_size

    def __call__(self, img):
        # NOTE(review): pad_to_size computes a border that makes each side
        # exactly target_size (negative borders for larger images). If
        # ImageOps.expand crops in that case, d_min == target_size here and
        # t_min == t_max; confirm whether that is intended.
        img = pad_to_size(img, self.target_size)
        d_min = min(img.size)
        img = T.RandomCrop(size=d_min)(img)
        # The resize target is drawn from [9/8, 12/8] * target_size, capped
        # by the available side length d_min.
        t_min = min(d_min, round(9 / 8 * self.target_size))
        t_max = min(d_min, round(12 / 8 * self.target_size))
        # random.randint is inclusive on both ends; the previous
        # ``t_max + 1`` could pick a size one pixel above the d_min cap.
        t = random.randint(t_min, t_max)
        img = T.Resize(t)(img)
        if min(img.size) < self.min_size:
            img = T.Resize(self.min_size)(img)
        return T.RandomCrop(size=self.target_size)(img)
|
|
|
|
|
class SquarePad(object):
    """Pad an image to a square by centering it with a fill value of 0.

    The shorter side is padded up to the length of the longer one; an odd
    leftover pixel goes on the right/bottom. Adapted from
    https://discuss.pytorch.org/t/how-to-resize-and-pad-in-a-torchvision-transforms-compose/71850/9
    """

    def __call__(self, image):
        w, h = image.size
        side = max(w, h)
        left = (side - w) // 2
        top = (side - h) // 2
        right = side - (w + left)
        bottom = side - (h + top)
        return F.pad(image, (left, top, right, bottom), 0, 'constant')
|
|
|
|
|
def create_image_of_text(text: str, width: int = 224, nrows: int = 2, color=(255, 255, 255), font=None) -> torch.Tensor:
    """Render ``text`` as a float image tensor, wrapped across ``nrows`` rows.

    The text is drawn on a single black strip ``nrows`` times as wide as one
    row. The strip is then cut into ``nrows`` equal pieces, which are stacked
    vertically. Wrapping therefore happens at pixel boundaries, not word
    boundaries. Text that does not fit on the strip is cut off.

    Args:
        text: Text to render.
        width: Width in pixels of the returned image, including the 5px
            left and right padding.
        nrows: Number of rows the text is wrapped over.
        color: RGB fill color of the text; the background is black.
        font: PIL font to draw with; ``ImageFont.load_default()`` if None.

    Returns:
        cap_img: float32 tensor of shape (3, 12 * nrows + 5, width): 12px per
        row plus 5px of bottom padding.
    """
    height = 12  # pixel height of one text row
    padding = 5  # border added on the left, right and bottom
    effective_width = width - 2 * padding

    # Draw the whole text on one long strip, nrows rows wide...
    cap_img = Image.new('RGB', (effective_width * nrows, height), color=(0, 0, 0))
    draw = ImageDraw.Draw(cap_img)
    draw.text((0, 0), text, color, font=font or ImageFont.load_default())

    cap_img = F.convert_image_dtype(F.pil_to_tensor(cap_img), torch.float32)
    # ...then cut it into nrows pieces of effective_width and stack them
    # along the height axis.
    cap_img = torch.split(cap_img, effective_width, dim=-1)
    cap_img = torch.cat(cap_img, dim=1)

    # Pad order is (left, right, top, bottom): 5px left/right, 0 top, 5px bottom.
    cap_img = torch.nn.functional.pad(cap_img, [padding, padding, 0, padding])
    return cap_img
|
|
|
|
|
def get_feature_extractor_for_model(model_name: str, image_size: int = 224, train: bool = True):
    """Load the HuggingFace ``AutoFeatureExtractor`` for ``model_name``.

    ``image_size`` and ``train`` are accepted for interface compatibility but
    are currently ignored.
    """
    message = f'Using HuggingFace AutoFeatureExtractor for {model_name}.'
    print(message)
    return AutoFeatureExtractor.from_pretrained(model_name)
|
|
|
|
|
def get_pixel_values_for_model(feature_extractor, img):
    """Run ``feature_extractor`` on ``img`` and return the first item's pixel values.

    The image is converted to RGB first. The extractor is asked for PyTorch
    tensors (``return_tensors="pt"``), and the leading batch entry of its
    ``pixel_values`` is returned.
    """
    rgb_img = img.convert('RGB')
    batch = feature_extractor(rgb_img, return_tensors="pt")
    return batch.pixel_values[0, ...]
|
|
|
|
|
def save_checkpoint(state, is_best, filename='checkpoint'):
    """Serialize ``state`` to ``<filename>.pth.tar`` with ``torch.save``.

    When ``is_best`` is truthy, the file just written is also copied to
    ``<filename>_best.pth.tar``.
    """
    latest_path = filename + '.pth.tar'
    torch.save(state, latest_path)
    if not is_best:
        return
    shutil.copyfile(latest_path, filename + '_best.pth.tar')
|
|
|
|
|
def accuracy(output, target, padding, topk=(1,)):
    """Compute top-k accuracy, in percent, over the non-padding targets.

    Args:
        output: Prediction scores whose last dimension indexes classes. The
            leading dimensions must match ``target``.
        target: Integer class labels.
        padding: Label value marking positions to ignore.
        topk: Values of k to report. If fewer than ``max(topk)`` classes
            exist, a warning is printed and all available classes are used.

    Returns:
        A list with one single-element tensor per k in ``topk``. Each holds
        the percentage of non-padding positions whose target is among the k
        highest-scoring classes. If every position is padding, the result is
        0 rather than NaN.
    """
    with torch.no_grad():
        maxk = max(topk)
        if output.shape[-1] < maxk:
            print(f"[WARNING] Less than {maxk} predictions available. Using {output.shape[-1]} for topk.")

        maxk = min(maxk, output.shape[-1])

        # pred: indices of the maxk best classes, shape (*target.shape, maxk).
        _, pred = output.topk(maxk, -1, True, True)

        # 1 for real targets, 0 for padding; used to zero out padded hits.
        mask = (target != padding).type(target.dtype)
        target_expand = target[..., None].expand_as(pred)
        correct = pred.eq(target_expand)
        correct = correct * mask[..., None].expand_as(correct)

        # Clamp so an all-padding batch gives 0% instead of 0/0 = NaN, which
        # would otherwise poison any running average it is fed into.
        num_valid = mask.sum().clamp(min=1)

        res = []
        for k in topk:
            correct_k = correct[..., :k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / num_valid))
        return res
|
|
|
|
|
def get_params_count(model, max_name_len: int = 60):
    """Tabulate a model's parameters and count trainable vs. frozen elements.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        max_name_len: Parameter names are truncated to this many characters.

    Returns:
        ``(params, total_trainable_params, total_nontrainable_params)``, where
        ``params`` is a list of ``(name, numel, shape_str, requires_grad)``
        tuples in ``named_parameters()`` order.
    """
    rows = []
    n_trainable = n_frozen = 0
    for param_name, tensor in model.named_parameters():
        numel = tensor.numel()
        rows.append((param_name[:max_name_len], numel, str(tuple(tensor.shape)), tensor.requires_grad))
        if tensor.requires_grad:
            n_trainable += numel
        else:
            n_frozen += numel
    return rows, n_trainable, n_frozen
|
|
|
|
|
def get_params_count_str(model, max_name_len: int = 60):
    """Format the output of ``get_params_count`` as a fixed-width text table.

    The table has one row per parameter (name, trainable flag, shape, element
    count), followed by trainable and non-trainable totals. Every line of the
    returned string, including the last, ends with a newline.
    """
    table_width = max_name_len + 70
    heavy_rule = '=' * table_width
    light_rule = '-' * table_width
    rows, n_trainable, n_frozen = get_params_count(model, max_name_len)

    lines = [
        heavy_rule,
        f'| {"Module":<{max_name_len}} | {"Trainable":<10} | {"Shape":>15} | {"Param Count":>12} |',
        light_rule,
    ]
    lines.extend(
        f'| {name:<{max_name_len}} | {"True" if trainable else "False":<10} | {shape:>15} | {count:>12,} |'
        for name, count, shape, trainable in rows
    )
    lines += [
        light_rule,
        f'| {"Total trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {n_trainable:>12,} |',
        f'| {"Total non-trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {n_frozen:>12,} |',
        heavy_rule,
    ]
    return '\n'.join(lines) + '\n'
|
|
|
|
|
class Summary(Enum):
    """Selects which statistic ``AverageMeter.summary()`` reports."""

    # Report nothing (empty summary string).
    NONE = 0
    # Report the running average.
    AVERAGE = 1
    # Report the running sum.
    SUM = 2
    # Report the accumulated count.
    COUNT = 3
|
|
|
|
|
class ProgressMeter(object):
    """Print progress lines built from a batch counter and a list of meters.

    Args:
        num_batches: Total number of batches; sets the counter's width.
        meters: Objects rendered with ``str()`` in ``display`` and with
            ``.summary()`` in ``display_summary``.
        prefix: Text placed before the batch counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, ``[batch/num_batches]`` and each meter, tab-separated."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def display_summary(self):
        """Print ``" *"`` followed by every meter's summary, space-separated."""
        parts = [" *"]
        parts.extend(m.summary() for m in self.meters)
        print(' '.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        # Yields e.g. '[{:3d}/100]': the current batch is right-aligned to
        # the width of the total so successive lines line up.
        width = len(str(num_batches // 1))
        counter = '{:%dd}' % width
        return f'[{counter}/{counter.format(num_batches)}]'
|
|
|
|
|
class AverageMeter(object):
    """Computes and stores the average and current value.

    Args:
        name: Label used when printing the meter.
        fmt: Format spec, including the leading colon (e.g. ``':6.2f'``),
            applied to ``val`` and ``avg`` in ``str(meter)``.
        summary_type: Which statistic ``summary()`` reports.
    """

    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Zero the latest value, running average, sum and count."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def _update_avg(self):
        # A zero count (update(..., n=0) on a fresh meter, or an all_reduce
        # where no rank recorded anything) used to raise ZeroDivisionError;
        # report an average of 0 instead.
        self.avg = self.sum / self.count if self.count else 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self._update_avg()

    def all_reduce(self):
        """Sum ``sum`` and ``count`` across all ``torch.distributed`` ranks.

        Requires an initialized process group. The reduction tensor is
        float32 and lives on CUDA when available, otherwise on the CPU.
        NOTE(review): this presumably matches the backend in use (NCCL vs.
        Gloo); confirm for mixed setups.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        self.sum, self.count = total.tolist()
        self._update_avg()

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """Return a one-line summary selected by ``summary_type``.

        Raises:
            ValueError: If ``summary_type`` is not a ``Summary`` member.
        """
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)

        return fmtstr.format(**self.__dict__)
|
|