# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/prismer/blob/main/LICENSE

import os
import re
import json
import torch
import PIL.Image as Image
import numpy as np
import torchvision.transforms as transforms
import torchvision.transforms.functional as transforms_f
from dataset.randaugment import RandAugment

# Pre-computed 64-dim CLIP text embeddings, one per semantic class, used to
# in-paint the expert label maps in post_label_process() below.
COCO_FEATURES = torch.load('dataset/coco_features.pt')['features']
ADE_FEATURES = torch.load('dataset/ade_features.pt')['features']
DETECTION_FEATURES = torch.load('dataset/detection_features.pt')['features']
BACKGROUND_FEATURES = torch.load('dataset/background_features.pt')

class Transform:
    def __init__(self, resize_resolution=384, scale_size=[0.5, 1.0], train=False):
        self.resize_size = [resize_resolution, resize_resolution]
        self.scale_size = scale_size
        self.train = train
        self.randaugment = RandAugment(2, 5)

    def __call__(self, image, labels):
        if self.train:
            # random resized crop, applied with identical parameters to the image and all label maps
            i, j, h, w = transforms.RandomResizedCrop.get_params(img=image, scale=self.scale_size, ratio=[3. / 4, 4. / 3])
            image = transforms_f.crop(image, i, j, h, w)
            if labels is not None:
                for exp in labels:
                    labels[exp] = transforms_f.crop(labels[exp], i, j, h, w)

        # resize to the defined shape
        image = transforms_f.resize(image, self.resize_size, transforms_f.InterpolationMode.BICUBIC)
        if labels is not None:
            for exp in labels:
                labels[exp] = transforms_f.resize(labels[exp], [224, 224], transforms_f.InterpolationMode.NEAREST)

        if self.train:
            # random horizontal flipping
            if torch.rand(1) > 0.5:
                image = transforms_f.hflip(image)
                if labels is not None:
                    for exp in labels:
                        labels[exp] = transforms_f.hflip(labels[exp])
            # random augmentation
            image, labels = self.randaugment(image, labels)

        # transform to tensor
        image = transforms_f.to_tensor(image)
        if labels is not None:
            for exp in labels:
                if exp in ['depth', 'normal', 'edge']:
                    labels[exp] = transforms_f.to_tensor(labels[exp])
                else:
                    labels[exp] = (transforms_f.to_tensor(labels[exp]) * 255).long()

        # apply normalisation (CLIP image statistics)
        image = transforms_f.normalize(image, mean=[0.48145466, 0.4578275, 0.40821073],
                                       std=[0.26862954, 0.26130258, 0.27577711])
        if labels is not None:
            return {'rgb': image, **labels}
        else:
            return {'rgb': image}
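
# Usage sketch (illustrative): the image and label dict below are hypothetical
# stand-ins; in practice both come from get_expert_labels() further down.
def _transform_usage_example():
    transform = Transform(resize_resolution=384, scale_size=[0.5, 1.0], train=True)
    image = Image.new('RGB', (640, 480))            # stand-in for a real photo
    labels = {'depth': Image.new('L', (640, 480))}  # stand-in expert label map
    out = transform(image, labels)
    # out['rgb'] is a normalised (3, 384, 384) float tensor;
    # out['depth'] is a (1, 224, 224) float tensor.
    return out
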
def get_expert_labels(data_path, label_path, image_path, dataset, experts):
    image_full_path = os.path.join(data_path, dataset, image_path)
    image = Image.open(image_full_path).convert('RGB')
    if experts != 'none':
        labels = {}
        labels_info = {}
        ps = image_path.split('.')[-1]  # file extension of the raw image
        for exp in experts:
            if exp in ['seg_coco', 'seg_ade', 'edge', 'depth']:
                label_full_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.png'))
                if os.stat(label_full_path).st_size > 0:
                    labels[exp] = Image.open(label_full_path).convert('L')
                else:
                    # empty label file: fall back to an all-zero map (uint8, since PIL cannot build an image from float64)
                    labels[exp] = Image.fromarray(np.zeros([image.size[1], image.size[0]], dtype=np.uint8)).convert('L')
            elif exp == 'normal':
                label_full_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.png'))
                if os.stat(label_full_path).st_size > 0:
                    labels[exp] = Image.open(label_full_path).convert('RGB')
                else:
                    labels[exp] = Image.fromarray(np.zeros([image.size[1], image.size[0], 3], dtype=np.uint8)).convert('RGB')
            elif exp == 'obj_detection':
                label_full_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.png'))
                if os.stat(label_full_path).st_size > 0:
                    labels[exp] = Image.open(label_full_path).convert('L')
                else:
                    # 255 marks background pixels
                    labels[exp] = Image.fromarray(255 * np.ones([image.size[1], image.size[0]], dtype=np.uint8)).convert('L')
                label_info_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.json'))
                labels_info[exp] = json.load(open(label_info_path, 'r'))
            elif exp == 'ocr_detection':
                label_full_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.png'))
                label_info_path = os.path.join(label_path, exp, dataset, image_path.replace(f'.{ps}', '.pt'))
                if os.path.exists(label_info_path):
                    labels[exp] = Image.open(label_full_path).convert('L')
                    labels_info[exp] = torch.load(label_info_path)
                else:
                    labels[exp] = Image.fromarray(255 * np.ones([image.size[1], image.size[0]], dtype=np.uint8)).convert('L')
                    labels_info[exp] = None
    else:
        labels, labels_info = None, None
    return image, labels, labels_info
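
# Usage sketch (illustrative): the roots, dataset name and relative image path
# below are hypothetical placeholders for a Prismer-style data layout.
def _expert_label_usage_example():
    image, labels, labels_info = get_expert_labels(
        data_path='data',               # hypothetical root holding raw images
        label_path='label',             # hypothetical root holding expert outputs
        image_path='train/000001.jpg',  # hypothetical relative image path
        dataset='coco',
        experts=['depth', 'seg_coco', 'obj_detection'],
    )
    # labels maps each expert name to a PIL image aligned with `image`;
    # labels_info holds per-instance metadata for the detection experts.
    return image, labels, labels_info
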
def post_label_process(inputs, labels_info):
    eps = 1e-6
    for exp in inputs:
        if exp in ['depth', 'normal', 'edge']:  # remap to -1 to 1 range
            inputs[exp] = 2 * (inputs[exp] - inputs[exp].min()) / (inputs[exp].max() - inputs[exp].min() + eps) - 1

        elif exp == 'seg_coco':  # in-paint with CLIP features
            text_emb = torch.empty([64, *inputs[exp].shape[1:]])
            for l in inputs[exp].unique():
                if l == 255:
                    text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
                else:
                    text_emb[:, (inputs[exp][0] == l)] = COCO_FEATURES[l].unsqueeze(-1)
            inputs[exp] = text_emb

        elif exp == 'seg_ade':  # in-paint with CLIP features
            text_emb = torch.empty([64, *inputs[exp].shape[1:]])
            for l in inputs[exp].unique():
                if l == 255:
                    text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
                else:
                    text_emb[:, (inputs[exp][0] == l)] = ADE_FEATURES[l].unsqueeze(-1)
            inputs[exp] = text_emb

        elif exp == 'obj_detection':  # in-paint with CLIP features
            text_emb = torch.empty([64, *inputs[exp].shape[1:]])
            label_map = labels_info[exp]
            for l in inputs[exp].unique():
                if l == 255:
                    text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
                else:
                    text_emb[:, (inputs[exp][0] == l)] = DETECTION_FEATURES[label_map[str(l.item())]].unsqueeze(-1)
            inputs[exp] = {'label': text_emb, 'instance': inputs[exp]}

        elif exp == 'ocr_detection':  # in-paint with CLIP features
            text_emb = torch.empty([64, *inputs[exp].shape[1:]])
            label_map = labels_info[exp]
            for l in inputs[exp].unique():
                if l == 255:
                    text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
                else:
                    text_emb[:, (inputs[exp][0] == l)] = label_map[l.item()]['features'].unsqueeze(-1)
            inputs[exp] = text_emb
    return inputs
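
# Usage sketch (illustrative): a single-class segmentation map at the 224x224
# label resolution produced by Transform is in-painted with CLIP features.
def _post_label_process_example():
    inputs = {'seg_coco': torch.zeros(1, 224, 224).long()}  # every pixel is class 0
    outputs = post_label_process(inputs, labels_info={})
    # outputs['seg_coco'] is a (64, 224, 224) tensor holding COCO_FEATURES[0]
    # broadcast over every pixel.
    return outputs
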
def pre_caption(caption, max_words=50):
    caption = re.sub(r"([.!\"()*#:;~])", ' ', caption.capitalize())  # remove special characters
    caption = re.sub(r"\s{2,}", ' ', caption)  # collapse repeated whitespace
    caption = caption.rstrip('\n')  # remove trailing newline
    caption = caption.strip(' ')  # remove leading and trailing white spaces

    # truncate caption to the maximum number of words
    caption_words = caption.split(' ')
    if len(caption_words) > max_words:
        caption = ' '.join(caption_words[:max_words])
    return caption
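
# Example: special characters are stripped and whitespace is collapsed, e.g.
#   pre_caption('a man riding a horse!!  on the beach...')
#   -> 'A man riding a horse on the beach'
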
def pre_question(question, max_words=50):
    question = re.sub(r"([.!\"()*#:;~])", ' ', question.capitalize())  # remove special characters
    question = question.strip()

    # truncate question to the maximum number of words
    question_words = question.split(' ')
    if len(question_words) > max_words:
        question = ' '.join(question_words[:max_words])
    if question[-1] != '?':
        question += '?'
    return question
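
# Example: a trailing '?' is appended when missing, e.g.
#   pre_question('what color is the dog') -> 'What color is the dog?'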