Spaces:
Runtime error
Runtime error
from utils.distributed import is_main_process, get_rank, get_world_size | |
import io | |
import json | |
import re | |
import numpy as np | |
from os.path import join | |
from tqdm import trange | |
from PIL import Image | |
from PIL import ImageFile | |
from torchvision.transforms import PILToTensor | |
ImageFile.LOAD_TRUNCATED_IMAGES = True | |
Image.MAX_IMAGE_PIXELS = None | |
def load_image_from_path(image_path, client):
    """Load an image and return it as a batched uint8 tensor.

    Paths beginning with ``'s3'`` or ``'p2'`` are fetched through
    ``client.Get(image_path)`` and decoded from memory; every other path is
    handed straight to ``PIL.Image.open``.

    Args:
        image_path: String path (or remote key) of the image.
        client: Object exposing ``Get(path)``. Only used for the ``'s3'`` /
            ``'p2'`` prefixes, so it is never touched for other paths.
            Presumably ``Get`` returns the encoded image bytes -- confirm
            against the client implementation.

    Returns:
        The RGB image converted with ``PILToTensor`` plus a leading
        dimension: shape (1, C, H, W), dtype torch.uint8.
    """
    # str.startswith accepts a tuple of prefixes: one call instead of an
    # `or` chain.
    if image_path.startswith(('s3', 'p2')):
        value = client.Get(image_path)
        img_bytes = np.frombuffer(value, dtype=np.uint8)
        source = io.BytesIO(img_bytes)
    else:
        source = image_path

    # Image.open is lazy and holds the underlying file open. convert()
    # forces the pixel data to be decoded and returns a separate image, so
    # the handle can be released deterministically as soon as it is done.
    with Image.open(source) as raw_image:
        image = raw_image.convert('RGB')  # PIL Image

    image = PILToTensor()(image).unsqueeze(0)  # (1, C, H, W), torch.uint8
    return image
def pre_text(text, max_l=None, pre_text=True):
    """Normalise a caption string and optionally cap its word count.

    When ``pre_text`` is truthy the text is lower-cased, the punctuation
    characters ``, . ' ! ? " ( ) * # : ; ~`` are removed, ``-`` and ``/``
    become spaces, ``<person>`` becomes ``person``, runs of two or more
    whitespace characters collapse to one space, and trailing newlines plus
    surrounding spaces are stripped. If ``max_l`` is truthy, the cleaned
    text is then cut down to its first ``max_l`` space-separated words.

    When ``pre_text`` is falsy the input is returned untouched and ``max_l``
    is ignored.

    Args:
        text: The string to clean.
        max_l: Maximum number of words to keep; ``None``/``0`` disables
            truncation.
        pre_text: Switch for the whole cleaning step. Note that inside this
            body the name refers to the flag, not to the function.

    Returns:
        The cleaned (and possibly truncated) string, or ``text`` unchanged.
    """
    if not pre_text:
        return text

    cleaned = re.sub(r"([,.'!?\"()*#:;~])", '', text.lower())
    for old, new in (('-', ' '), ('/', ' '), ('<person>', 'person')):
        cleaned = cleaned.replace(old, new)
    cleaned = re.sub(r"\s{2,}", ' ', cleaned)
    cleaned = cleaned.rstrip('\n').strip(' ')

    if not max_l:
        return cleaned

    # Truncate only when there are more words than allowed.
    tokens = cleaned.split(' ')
    return ' '.join(tokens[:max_l]) if len(tokens) > max_l else cleaned