import io
import json
import re
from os.path import join

import numpy as np
from PIL import Image
from PIL import ImageFile
from torchvision.transforms import PILToTensor
from tqdm import trange

from utils.distributed import is_main_process, get_rank, get_world_size

# Let PIL decode files whose data stream is cut short, and disable its
# pixel-count limit so very large images are not rejected.
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None


def load_image_from_path(image_path, client):
    """Load an image as an RGB uint8 tensor with a leading batch dimension.

    Args:
        image_path: Location of the image. Paths starting with ``'s3'`` or
            ``'p2'`` are fetched through ``client``; anything else is opened
            as a local file.
        client: Object exposing ``Get(path)`` that returns the raw encoded
            image bytes. Only used for ``'s3'`` / ``'p2'`` paths.

    Returns:
        Tensor of shape (1, C, H, W) with dtype ``torch.uint8``.
    """
    if image_path.startswith(('s3', 'p2')):
        value = client.Get(image_path)
        img_bytes = np.frombuffer(value, dtype=np.uint8)
        buff = io.BytesIO(img_bytes)
        image = Image.open(buff).convert('RGB')
    else:
        # convert() returns an independent copy, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(image_path) as img:
            image = img.convert('RGB')  # PIL Image
    image = PILToTensor()(image).unsqueeze(0)  # (1, C, H, W), torch.uint8
    return image


def pre_text(text, max_l=None, pre_text=True):
    """Normalize a caption and optionally truncate it to ``max_l`` words.

    When ``pre_text`` is true the text is lower-cased, the punctuation
    characters ``, . ' ! ? " ( ) * # : ; ~`` are removed, ``-`` and ``/``
    become spaces, the ``<person>`` placeholder token becomes the plain word
    ``person``, runs of whitespace collapse to a single space, and
    surrounding newlines/spaces are stripped.

    Args:
        text: Input string.
        max_l: Maximum number of space-separated words to keep. A falsy
            value (``None`` or ``0``) disables truncation.
        pre_text: When false, ``text`` is returned unchanged (neither
            normalized nor truncated).

    Returns:
        The processed string.
    """
    if not pre_text:
        return text

    text = re.sub(r"([,.'!?\"()*#:;~])", '', text.lower())
    # The search token must be non-empty: replacing '' would insert
    # 'person' between every character of the caption.
    text = text.replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
    text = re.sub(r"\s{2,}", ' ', text)
    text = text.rstrip('\n').strip(' ')

    if max_l:  # truncate
        words = text.split(' ')
        if len(words) > max_l:
            text = ' '.join(words[:max_l])
    return text