# fromage / fromage / utils.py
# jykoh: Fix truncation (8041bf9)
from enum import Enum
import subprocess
import sys
import shutil
import torch
import torch.distributed as dist
from torchvision.transforms import functional as F
from torchvision import transforms as T
from transformers import AutoFeatureExtractor
from PIL import Image, ImageDraw, ImageFont, ImageOps
import requests
from io import BytesIO
import random
def dump_git_status(out_file=sys.stdout, exclude_file_patterns=('*.ipynb', '*.th', '*.sh', '*.txt', '*.json')):
    """Logs the current git commit hash and working-tree diff to `out_file`.

    Args:
        out_file: File object handed to the child processes as their stdout.
            Defaults to the `sys.stdout` object bound when this module was
            imported.
        exclude_file_patterns: Iterable of git pathspec patterns; files
            matching any of them are left out of the diff. An empty iterable
            (or None) diffs everything under the current directory.
    """
    import shlex  # Local import: only this function needs it.

    subprocess.call('git rev-parse HEAD', shell=True, stdout=out_file)
    # Blank line separating the commit hash from the diff.
    subprocess.call('echo', shell=True, stdout=out_file)
    # Turn every pattern into an `:(exclude)` pathspec. Quoting keeps the shell
    # from expanding the globs or tripping over the parentheses.
    exclude_string = ' '.join(
        shlex.quote(':(exclude)' + pattern)
        for pattern in (exclude_file_patterns or ())
    )
    subprocess.call('git --no-pager diff -- . {}'.format(exclude_string), shell=True, stdout=out_file)
def get_image_from_url(url: str, timeout=30):
    """Downloads an image and returns it as a 224x224 RGB PIL image.

    Args:
        url: Address fetched with an HTTP GET request.
        timeout: Seconds passed to `requests.get` as its timeout; None waits
            indefinitely (the previous behaviour).

    Returns:
        The decoded image, resized to 224x224 (aspect ratio is not preserved)
        and then converted to RGB.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within `timeout`.
    """
    response = requests.get(url, timeout=timeout)
    # Fail on HTTP errors here instead of trying to decode an error page as an image.
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
    # Keep the original order: resize first, then convert to RGB.
    img = img.resize((224, 224))
    return img.convert('RGB')
def truncate_caption(caption: str) -> str:
    """Cuts a caption after its first newline or, failing that, its first period.

    Leading and trailing newlines are stripped first. The delimiter that
    triggers the cut is kept at the end of the result. A caption containing
    neither delimiter is returned stripped but otherwise unchanged.
    """
    text = caption.strip('\n')
    # A newline takes precedence; a period is only considered when no newline remains.
    for delimiter in ('\n', '.'):
        head, found, _ = text.partition(delimiter)
        if found:
            return head + found
    return text
def pad_to_size(x, size=256):
    """Adds a border around image `x` so each side measures `size`.

    The missing width/height is split between the two opposite edges, with
    the right/bottom edge receiving the extra pixel when the amount is odd.

    Args:
        x: Image exposing `.size` as (width, height) and accepted by
            `ImageOps.expand`.
        size: Target side length in pixels.

    Returns:
        The image returned by `ImageOps.expand`.
    """
    width, height = x.size
    extra_w = size - width
    extra_h = size - height
    left = extra_w // 2
    top = extra_h // 2
    # Border order expected by ImageOps.expand: (left, top, right, bottom).
    border = (left, top, extra_w - left, extra_h - top)
    return ImageOps.expand(x, border)
class RandCropResize(object):
    """
    Randomly crops, then randomly resizes, then randomly crops again, an image. Mirroring the augmentations from https://arxiv.org/abs/2102.12092
    """

    def __init__(self, target_size):
        # Side length of the square crop returned by __call__.
        self.target_size = target_size

    def __call__(self, img):
        """Applies pad -> square random crop -> random resize -> final random crop.

        Args:
            img: Image accepted by `pad_to_size` and the torchvision transforms.

        Returns:
            A random `target_size` crop of the augmented image.
        """
        img = pad_to_size(img, self.target_size)
        # Square crop whose side equals the shorter image side.
        d_min = min(img.size)
        img = T.RandomCrop(size=d_min)(img)
        # Resize target drawn from [9/8, 12/8] * target_size, capped at d_min.
        t_min = min(d_min, round(9 / 8 * self.target_size))
        t_max = min(d_min, round(12 / 8 * self.target_size))
        # random.randint includes both endpoints, so `t_max` is the upper bound;
        # `t_max + 1` would allow a size one pixel above the intended maximum.
        t = random.randint(t_min, t_max)
        img = T.Resize(t)(img)
        # NOTE(review): 256 is hard-coded here rather than derived from
        # target_size -- confirm this is intended.
        if min(img.size) < 256:
            img = T.Resize(256)(img)
        return T.RandomCrop(size=self.target_size)(img)
class SquarePad(object):
    """Pads image to square.
    From https://discuss.pytorch.org/t/how-to-resize-and-pad-in-a-torchvision-transforms-compose/71850/9
    """

    def __call__(self, image):
        """Zero-pads `image` so both sides equal its longer side.

        The padding is split between opposite edges; when the amount is odd
        the right/bottom edge gets the extra pixel.
        """
        width, height = image.size
        side = max(width, height)
        left = (side - width) // 2
        top = (side - height) // 2
        right = side - (width + left)
        bottom = side - (height + top)
        return F.pad(image, (left, top, right, bottom), 0, 'constant')
def create_image_of_text(text: str, width: int = 224, nrows: int = 2, color=(255, 255, 255), font=None) -> torch.Tensor:
    """Renders `text` onto a black background, wrapped over `nrows` rows.

    The text is drawn on one long strip, which is then cut into `nrows`
    equal segments that are stacked vertically. Text that does not fit on
    the strip is clipped.

    Args:
        text: String to draw.
        width: Width in pixels of the returned image, including 5 pixels of
            padding on the left and right.
        nrows: Number of stacked rows.
        color: Fill colour passed to `ImageDraw.text`.
        font: Font to draw with; falls back to `ImageFont.load_default()`.

    Returns:
        cap_img: (3, 12 * nrows + 5, width) float32 image of wrapped text.
    """
    row_height = 12
    margin = 5
    row_width = width - 2 * margin

    # One long black strip holding every row side by side.
    canvas = Image.new('RGB', (row_width * nrows, row_height), color=(0, 0, 0))
    ImageDraw.Draw(canvas).text((0, 0), text, color, font=font or ImageFont.load_default())

    strip = F.convert_image_dtype(F.pil_to_tensor(canvas), torch.float32)  # (3, row_height, row_width * nrows)
    rows = torch.split(strip, row_width, dim=-1)  # nrows chunks of (3, row_height, row_width)
    stacked = torch.cat(rows, dim=1)  # (3, row_height * nrows, row_width)

    # Zero padding on the left, right and bottom; none on top.
    return torch.nn.functional.pad(stacked, [margin, margin, 0, margin])
def get_feature_extractor_for_model(model_name: str, image_size: int = 224, train: bool = True):
    """Loads the HuggingFace feature extractor registered for `model_name`.

    Args:
        model_name: Identifier passed to `AutoFeatureExtractor.from_pretrained`.
        image_size: Accepted but not used by this function.
        train: Accepted but not used by this function.

    Returns:
        The object returned by `AutoFeatureExtractor.from_pretrained`.
    """
    print(f'Using HuggingFace AutoFeatureExtractor for {model_name}.')
    return AutoFeatureExtractor.from_pretrained(model_name)
def get_pixel_values_for_model(feature_extractor, img):
    """Runs `feature_extractor` on one image and returns its pixel tensor.

    Args:
        feature_extractor: Callable taking an image and `return_tensors`,
            returning an object with a `pixel_values` attribute.
        img: Image exposing `.convert('RGB')`.

    Returns:
        The first entry along the leading dimension of `pixel_values`,
        expected to be (3, H, W).
    """
    rgb_image = img.convert('RGB')
    encoded = feature_extractor(rgb_image, return_tensors="pt")
    # Drop the leading dimension added by the extractor.
    return encoded.pixel_values[0, ...]
def save_checkpoint(state, is_best, filename='checkpoint'):
    """Serialises `state` to `<filename>.pth.tar`, duplicating it when it is the best so far.

    Args:
        state: Object handed to `torch.save`.
        is_best: When truthy, the saved file is also copied to
            `<filename>_best.pth.tar`.
        filename: Path prefix (without extension) for the checkpoint files.
    """
    latest_path = filename + '.pth.tar'
    torch.save(state, latest_path)
    if not is_best:
        return
    shutil.copyfile(latest_path, filename + '_best.pth.tar')
def accuracy(output, target, padding, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Positions whose target equals `padding` are ignored, both when counting
    hits and when normalising.

    Args:
        output: Score tensor whose last dimension ranges over the candidates;
            its leading dimensions must match `target`.
        target: Tensor of target indices.
        padding: Target value marking positions to ignore.
        topk: Iterable of k values. Values of k larger than the number of
            candidates are evaluated on all candidates (a warning is printed).

    Returns:
        A list with one single-element float tensor per k, holding the
        percentage of non-padding positions whose target is among the k
        highest-scoring candidates. If every position is padding the result
        is 0 rather than NaN.
    """
    with torch.no_grad():
        num_candidates = output.shape[-1]
        maxk = max(topk)
        if num_candidates < maxk:
            print(f"[WARNING] Less than {maxk} predictions available. Using {num_candidates} for topk.")
        maxk = min(maxk, num_candidates)

        # Take topk along the last dimension.
        _, pred = output.topk(maxk, -1, True, True)

        # 1 for real targets, 0 for padding, in the target's dtype.
        mask = (target != padding).type(target.dtype)
        correct = pred.eq(target[..., None].expand_as(pred))
        correct = correct * mask[..., None].expand_as(correct)

        # Clamp so a fully padded batch divides by 1 and reports 0 instead of NaN.
        scale = 100.0 / mask.sum().clamp(min=1)

        res = []
        for k in topk:
            correct_k = correct[..., :k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(scale))
        return res
def get_params_count(model, max_name_len: int = 60):
    """Collects per-parameter statistics and trainable/frozen totals for `model`.

    Args:
        model: Object exposing `named_parameters()`.
        max_name_len: Parameter names are cut to this many characters.

    Returns:
        A tuple `(rows, trainable_total, frozen_total)` where each row is
        `(name, element count, shape as a string, requires_grad)`, and the
        totals sum the element counts of parameters that do / do not require
        gradients.
    """
    rows = []
    trainable_total = 0
    frozen_total = 0
    for full_name, tensor in model.named_parameters():
        count = tensor.numel()
        rows.append((full_name[:max_name_len], count, str(tuple(tensor.shape)), tensor.requires_grad))
        if tensor.requires_grad:
            trainable_total += count
        else:
            frozen_total += count
    return rows, trainable_total, frozen_total
def get_params_count_str(model, max_name_len: int = 60):
    """Formats the parameter statistics of `model` as a text table.

    Args:
        model: Object forwarded to `get_params_count`.
        max_name_len: Width of the name column; longer names are cut.

    Returns:
        A multi-line string: a header, one row per parameter (name,
        trainable flag, shape, element count) and two rows with the
        trainable and non-trainable totals, framed by rule lines.
    """
    padding = 70  # Hardcoded depending on desired amount of padding and separators.
    rows, n_trainable, n_frozen = get_params_count(model, max_name_len)

    rule_width = max_name_len + padding
    heavy_rule = '=' * rule_width + '\n'
    light_rule = '-' * rule_width + '\n'

    lines = [
        heavy_rule,
        f'| {"Module":<{max_name_len}} | {"Trainable":<10} | {"Shape":>15} | {"Param Count":>12} |\n',
        light_rule,
    ]
    lines.extend(
        f'| {name:<{max_name_len}} | {"True" if trainable else "False":<10} | {shape:>15} | {param_count:>12,} |\n'
        for name, param_count, shape, trainable in rows
    )
    lines.append(light_rule)
    lines.append(f'| {"Total trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {n_trainable:>12,} |\n')
    lines.append(f'| {"Total non-trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {n_frozen:>12,} |\n')
    lines.append(heavy_rule)
    return ''.join(lines)
class Summary(Enum):
    """Selects which statistic `AverageMeter.summary()` reports."""
    NONE = 0  # Report nothing (empty string).
    AVERAGE = 1  # Report the running average.
    SUM = 2  # Report the accumulated sum.
    COUNT = 3  # Report the accumulated count.
class ProgressMeter(object):
    """Prints a batch counter followed by the state of a group of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        """Stores the meters and prepares the `[current/total]` format string.

        Args:
            num_batches: Total number of batches, shown as the denominator.
            meters: Iterable of meter objects; `str(meter)` is used by
                `display` and `meter.summary()` by `display_summary`.
            prefix: Text printed before the batch counter.
        """
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Prints the prefix, the batch counter and every meter, tab-separated."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *(str(meter) for meter in self.meters)]))

    def display_summary(self):
        """Prints ` *` followed by every meter's summary, space-separated."""
        print(' '.join([" *", *(meter.summary() for meter in self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Builds a format string such as `[{:3d}/100]` for the batch counter."""
        # Pad the running index to the width of the total so columns line up.
        width = len(str(num_batches // 1))
        slot = f'{{:{width}d}}'
        return f'[{slot}/{slot.format(num_batches)}]'
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        """Creates an empty meter.

        Args:
            name: Label printed in front of the values.
            fmt: Format spec (including the leading colon) applied to the
                current value and the average in `__str__`.
            summary_type: `Summary` member choosing what `summary()` reports.
        """
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Clears the current value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Records `val` as the latest value, weighted by `n` observations."""
        self.val = val
        self.sum += val * n
        self.count += n
        self._refresh_avg()

    def all_reduce(self):
        """Sums `sum` and `count` across all distributed processes and refreshes `avg`."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        # Both values come back as Python floats from the float32 tensor.
        self.sum, self.count = total.tolist()
        self._refresh_avg()

    def _refresh_avg(self):
        """Recomputes `avg`, keeping it at 0 while nothing has been counted."""
        # Guards against ZeroDivisionError, e.g. update(..., n=0) on a fresh
        # meter or all_reduce() before any update.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """Returns the one-line summary selected by `summary_type`.

        Raises:
            ValueError: If `summary_type` is not a `Summary` member.
        """
        fmtstr = ''
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)
        return fmtstr.format(**self.__dict__)