Spaces:
Runtime error
Runtime error
import torch | |
from torchvision import transforms | |
from PIL import Image | |
import random | |
from data_utils.randaugment import RandomAugment | |
from .builder import PROCESSORS | |
class CaptionProcessor:
    """Preprocess (image, text) pairs for caption-style samples.

    Images go through ``RandomResizedCrop`` -> ``RandomHorizontalFlip`` ->
    (optionally ``RandomAugment``) -> ``ToTensor`` -> ``Normalize``.
    Text is reduced to a ``{"prompt": ..., "completion": ...}`` dict.

    Args:
        image_size: Output size passed to ``RandomResizedCrop``.
        min_scale: Lower bound of the crop scale range ``(min_scale, 1.0)``.
        randaug: If true, insert ``RandomAugment`` right after the flip.
    """

    # Per-channel mean/std applied by Normalize after ToTensor.
    # NOTE(review): these appear to be the widely published CLIP preprocessing
    # statistics — confirm that is the intended source.
    _NORM_MEAN = (0.48145466, 0.4578275, 0.40821073)
    _NORM_STD = (0.26862954, 0.26130258, 0.27577711)

    # Operation names handed to RandomAugment when ``randaug`` is enabled.
    # Stored as a tuple and copied into a fresh list per instance so no
    # mutable object is shared between processors.
    _RANDAUG_OPS = (
        'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
        'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
    )

    def __init__(self, image_size=224, min_scale=0.5, randaug=False):
        self.image_size = image_size
        self.min_scale = min_scale

        # Both configurations share the same head and tail; only the optional
        # RandomAugment step in the middle differs.
        steps = [
            transforms.RandomResizedCrop(image_size, scale=(min_scale, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
        ]
        if randaug:
            steps.append(RandomAugment(2, 7, isPIL=True, augs=list(self._RANDAUG_OPS)))
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(self._NORM_MEAN, self._NORM_STD),
        ]
        self.image_transform = transforms.Compose(steps)
        self.text_transform = None

    def __call__(self, image, text):
        """Preprocess one (image, text) pair.

        Args:
            image: Input for ``self.image_transform`` (presumably a PIL image —
                confirm with callers), or ``None`` to skip image processing.
            text: Dict with a ``"prompt"`` entry (a string, or a list of
                candidate strings from which one is drawn with
                ``random.choice``) and a ``"text"`` entry. ``None`` or an
                empty dict skips text processing.

        Returns:
            ``(image_input, text_input)``. Each element is ``None`` when the
            corresponding input was not provided. ``text_input`` is
            ``{"prompt": <str>, "completion": text["text"]}``.

        Raises:
            ValueError: If neither an image nor text is provided.
        """
        # Compare with None instead of testing truthiness: array/tensor images
        # raise on bool() ("truth value ... is ambiguous").
        has_image = image is not None
        has_text = bool(text)
        if not (has_image or has_text):
            # Explicit raise rather than `assert`, which `python -O` strips.
            raise ValueError("CaptionProcessor requires an image, text, or both")

        image_input = self.image_transform(image) if has_image else None

        text_input = None
        if has_text:
            prompt = text["prompt"]
            if isinstance(prompt, list):
                # Several prompt templates: sample one per call.
                prompt = random.choice(prompt)
            text_input = dict(
                prompt=prompt,
                completion=text["text"],
            )
        return image_input, text_input