Spaces:
Build error
Build error
from enum import Enum | |
import subprocess | |
import sys | |
import shutil | |
import torch | |
import torch.distributed as dist | |
from torchvision.transforms import functional as F | |
from torchvision import transforms as T | |
from transformers import AutoFeatureExtractor | |
from PIL import Image, ImageDraw, ImageFont, ImageOps | |
import requests | |
from io import BytesIO | |
import random | |
def dump_git_status(out_file=sys.stdout, exclude_file_patterns=['*.ipynb', '*.th', '*.sh', '*.txt', '*.json']): | |
"""Logs git status to stdout.""" | |
subprocess.call('git rev-parse HEAD', shell=True, stdout=out_file) | |
subprocess.call('echo', shell=True, stdout=out_file) | |
exclude_string = '' | |
subprocess.call('git --no-pager diff -- . {}'.format(exclude_string), shell=True, stdout=out_file) | |
def get_image_from_url(url: str): | |
response = requests.get(url) | |
img = Image.open(BytesIO(response.content)) | |
img = img.resize((224, 224)) | |
img = img.convert('RGB') | |
return img | |
def truncate_caption(caption: str) -> str: | |
"""Truncate captions at periods and newlines.""" | |
caption = caption.strip('\n') | |
trunc_index = caption.find('\n') + 1 | |
if trunc_index <= 0: | |
trunc_index = caption.find('.') + 1 | |
if trunc_index > 0: | |
caption = caption[:trunc_index] | |
return caption | |
def pad_to_size(x, size=256): | |
delta_w = size - x.size[0] | |
delta_h = size - x.size[1] | |
padding = ( | |
delta_w // 2, | |
delta_h // 2, | |
delta_w - (delta_w // 2), | |
delta_h - (delta_h // 2), | |
) | |
new_im = ImageOps.expand(x, padding) | |
return new_im | |
class RandCropResize(object): | |
""" | |
Randomly crops, then randomly resizes, then randomly crops again, an image. Mirroring the augmentations from https://arxiv.org/abs/2102.12092 | |
""" | |
def __init__(self, target_size): | |
self.target_size = target_size | |
def __call__(self, img): | |
img = pad_to_size(img, self.target_size) | |
d_min = min(img.size) | |
img = T.RandomCrop(size=d_min)(img) | |
t_min = min(d_min, round(9 / 8 * self.target_size)) | |
t_max = min(d_min, round(12 / 8 * self.target_size)) | |
t = random.randint(t_min, t_max + 1) | |
img = T.Resize(t)(img) | |
if min(img.size) < 256: | |
img = T.Resize(256)(img) | |
return T.RandomCrop(size=self.target_size)(img) | |
class SquarePad(object): | |
"""Pads image to square. | |
From https://discuss.pytorch.org/t/how-to-resize-and-pad-in-a-torchvision-transforms-compose/71850/9 | |
""" | |
def __call__(self, image): | |
max_wh = max(image.size) | |
p_left, p_top = [(max_wh - s) // 2 for s in image.size] | |
p_right, p_bottom = [max_wh - (s+pad) for s, pad in zip(image.size, [p_left, p_top])] | |
padding = (p_left, p_top, p_right, p_bottom) | |
return F.pad(image, padding, 0, 'constant') | |
def create_image_of_text(text: str, width: int = 224, nrows: int = 2, color=(255, 255, 255), font=None) -> torch.Tensor: | |
"""Creates a (3, nrows * 14, width) image of text. | |
Returns: | |
cap_img: (3, 14 * nrows, width) image of wrapped text. | |
""" | |
height = 12 | |
padding = 5 | |
effective_width = width - 2 * padding | |
# Create a black image to draw text on. | |
cap_img = Image.new('RGB', (effective_width * nrows, height), color = (0, 0, 0)) | |
draw = ImageDraw.Draw(cap_img) | |
draw.text((0, 0), text, color, font=font or ImageFont.load_default()) | |
cap_img = F.convert_image_dtype(F.pil_to_tensor(cap_img), torch.float32) # (3, height, W * nrows) | |
cap_img = torch.split(cap_img, effective_width, dim=-1) # List of nrow elements of shape (3, height, W) | |
cap_img = torch.cat(cap_img, dim=1) # (3, height * nrows, W) | |
# Add zero padding. | |
cap_img = torch.nn.functional.pad(cap_img, [padding, padding, 0, padding]) | |
return cap_img | |
def get_feature_extractor_for_model(model_name: str, image_size: int = 224, train: bool = True): | |
print(f'Using HuggingFace AutoFeatureExtractor for {model_name}.') | |
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name) | |
return feature_extractor | |
def get_pixel_values_for_model(feature_extractor, img): | |
pixel_values = feature_extractor( | |
img.convert('RGB'), | |
return_tensors="pt").pixel_values[0, ...] # (3, H, W) | |
return pixel_values | |
def save_checkpoint(state, is_best, filename='checkpoint'): | |
torch.save(state, filename + '.pth.tar') | |
if is_best: | |
shutil.copyfile(filename + '.pth.tar', filename + '_best.pth.tar') | |
def accuracy(output, target, padding, topk=(1,)): | |
"""Computes the accuracy over the k top predictions for the specified values of k""" | |
with torch.no_grad(): | |
maxk = max(topk) | |
if output.shape[-1] < maxk: | |
print(f"[WARNING] Less than {maxk} predictions available. Using {output.shape[-1]} for topk.") | |
maxk = min(maxk, output.shape[-1]) | |
batch_size = target.size(0) | |
# Take topk along the last dimension. | |
_, pred = output.topk(maxk, -1, True, True) # (N, T, topk) | |
mask = (target != padding).type(target.dtype) | |
target_expand = target[..., None].expand_as(pred) | |
correct = pred.eq(target_expand) | |
correct = correct * mask[..., None].expand_as(correct) | |
res = [] | |
for k in topk: | |
correct_k = correct[..., :k].reshape(-1).float().sum(0, keepdim=True) | |
res.append(correct_k.mul_(100.0 / mask.sum())) | |
return res | |
def get_params_count(model, max_name_len: int = 60): | |
params = [(name[:max_name_len], p.numel(), str(tuple(p.shape)), p.requires_grad) for name, p in model.named_parameters()] | |
total_trainable_params = sum([x[1] for x in params if x[-1]]) | |
total_nontrainable_params = sum([x[1] for x in params if not x[-1]]) | |
return params, total_trainable_params, total_nontrainable_params | |
def get_params_count_str(model, max_name_len: int = 60): | |
padding = 70 # Hardcoded depending on desired amount of padding and separators. | |
params, total_trainable_params, total_nontrainable_params = get_params_count(model, max_name_len) | |
param_counts_text = '' | |
param_counts_text += '=' * (max_name_len + padding) + '\n' | |
param_counts_text += f'| {"Module":<{max_name_len}} | {"Trainable":<10} | {"Shape":>15} | {"Param Count":>12} |\n' | |
param_counts_text += '-' * (max_name_len + padding) + '\n' | |
for name, param_count, shape, trainable in params: | |
param_counts_text += f'| {name:<{max_name_len}} | {"True" if trainable else "False":<10} | {shape:>15} | {param_count:>12,} |\n' | |
param_counts_text += '-' * (max_name_len + padding) + '\n' | |
param_counts_text += f'| {"Total trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_trainable_params:>12,} |\n' | |
param_counts_text += f'| {"Total non-trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_nontrainable_params:>12,} |\n' | |
param_counts_text += '=' * (max_name_len + padding) + '\n' | |
return param_counts_text | |
class Summary(Enum): | |
NONE = 0 | |
AVERAGE = 1 | |
SUM = 2 | |
COUNT = 3 | |
class ProgressMeter(object): | |
def __init__(self, num_batches, meters, prefix=""): | |
self.batch_fmtstr = self._get_batch_fmtstr(num_batches) | |
self.meters = meters | |
self.prefix = prefix | |
def display(self, batch): | |
entries = [self.prefix + self.batch_fmtstr.format(batch)] | |
entries += [str(meter) for meter in self.meters] | |
print('\t'.join(entries)) | |
def display_summary(self): | |
entries = [" *"] | |
entries += [meter.summary() for meter in self.meters] | |
print(' '.join(entries)) | |
def _get_batch_fmtstr(self, num_batches): | |
num_digits = len(str(num_batches // 1)) | |
fmt = '{:' + str(num_digits) + 'd}' | |
return '[' + fmt + '/' + fmt.format(num_batches) + ']' | |
class AverageMeter(object): | |
"""Computes and stores the average and current value""" | |
def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE): | |
self.name = name | |
self.fmt = fmt | |
self.summary_type = summary_type | |
self.reset() | |
def reset(self): | |
self.val = 0 | |
self.avg = 0 | |
self.sum = 0 | |
self.count = 0 | |
def update(self, val, n=1): | |
self.val = val | |
self.sum += val * n | |
self.count += n | |
self.avg = self.sum / self.count | |
def all_reduce(self): | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device) | |
dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False) | |
self.sum, self.count = total.tolist() | |
self.avg = self.sum / self.count | |
def __str__(self): | |
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' | |
return fmtstr.format(**self.__dict__) | |
def summary(self): | |
fmtstr = '' | |
if self.summary_type is Summary.NONE: | |
fmtstr = '' | |
elif self.summary_type is Summary.AVERAGE: | |
fmtstr = '{name} {avg:.3f}' | |
elif self.summary_type is Summary.SUM: | |
fmtstr = '{name} {sum:.3f}' | |
elif self.summary_type is Summary.COUNT: | |
fmtstr = '{name} {count:.3f}' | |
else: | |
raise ValueError('invalid summary type %r' % self.summary_type) | |
return fmtstr.format(**self.__dict__) | |