Spaces:
Runtime error
Runtime error
import time | |
import requests | |
from io import BytesIO | |
from os import path | |
from torch.utils.data import Dataset | |
from PIL import Image | |
class TestImageSetOnline(Dataset):
    """COCO test images downloaded on demand and preprocessed for CLIP.

    Each item is fetched over HTTP from its record's ``coco_url``. Grayscale
    ('L'), CMYK and RGBA images are converted to RGB. The image is then passed
    through a Hugging Face CLIP-style processor.

    Failed fetches are retried after an exponential back-off sleep, and the
    sleep shrinks again after each success.
    """

    def __init__(self, processor, image_list, timeout_base=0.5, timeout_mul=2,
                 request_timeout=None, max_retries=None):
        """
        Args:
            processor (CLIP preprocessor): callable invoked as
                ``processor(text=None, images=img, return_tensors='pt')``.
                Its result must have a ``'pixel_values'`` entry.
            image_list (list): indexable sequence of dict-like records, each
                with ``'coco_url'`` and ``'id'`` keys. NOTE(review): this was
                previously documented as a pandas.DataFrame, but ``df[i]``
                selects a column rather than a row. A list of dicts (e.g. the
                ``images`` list of a COCO annotation file) is presumably what
                callers pass.
            timeout_base (float, optional): initial back-off sleep, in
                seconds, between failed attempts. After successes the sleep
                shrinks back toward this value. Defaults to 0.5.
            timeout_mul (int, optional): factor the back-off sleep is
                multiplied by after each failed request and divided by after
                each successful one. Defaults to 2.
            request_timeout (float or tuple, optional): forwarded to
                ``requests.get`` as ``timeout``. ``None`` (default) waits
                indefinitely, as before.
            max_retries (int, optional): number of retries allowed after the
                first failed attempt before the last error is re-raised.
                ``None`` (default) retries forever, as before.
        """
        self.image_list = image_list
        self.processor = processor
        self.timeout_base = timeout_base
        # Current back-off sleep; adapted by _fetch after every attempt.
        self.timeout = self.timeout_base
        self.timeout_mul = timeout_mul
        self.request_timeout = request_timeout
        self.max_retries = max_retries

    def _preprocess(self, img):
        """Convert ``img`` to RGB if needed and return its processed pixels."""
        # L is grayscale; CMYK and RGBA have a channel layout other than RGB.
        if img.mode in ('L', 'CMYK', 'RGBA'):
            img = img.convert('RGB')
        # ``return_tensors`` (plural) is the processor's documented keyword;
        # the former ``return_tensor`` did not request 'pt' tensors.
        ret = self.processor(text=None, images=img, return_tensors='pt')
        # A single image was passed, so take the first (only) entry.
        return ret['pixel_values'][0]

    def _fetch(self, _id, url):
        """Download ``url`` and preprocess it, retrying with back-off on failure.

        Args:
            _id (str): identifier used only in the failure log line.
            url (str): image URL to download.

        Returns:
            tuple: ``(pixel_values, size)``, where ``size`` is the
            ``(width, height)`` of the downloaded image before preprocessing.

        Raises:
            Exception: the last error, once ``self.max_retries`` retries have
            failed. Never raised when ``max_retries`` is ``None``.
        """
        failures = 0
        while True:
            try:
                response = requests.get(url, timeout=self.request_timeout)
                # Surface HTTP errors directly instead of as decode failures.
                response.raise_for_status()
                img = Image.open(BytesIO(response.content))
                img_s = img.size
                pixel_values = self._preprocess(img)
            except Exception as e:
                # KeyboardInterrupt derives from BaseException, so it is not
                # caught here and still aborts the loop immediately.
                print(f"{_id} {url}: {str(e)}")
                failures += 1
                if self.max_retries is not None and failures > self.max_retries:
                    raise
                time.sleep(self.timeout)
                # Back off further before the next request.
                self.timeout *= self.timeout_mul
            else:
                # Success: relax the back-off, but not below the base value.
                if self.timeout > self.timeout_base:
                    self.timeout /= self.timeout_mul
                return pixel_values, img_s

    def __getitem__(self, index):
        """Return ``(id, url, pixel_values, size)`` for record ``index``."""
        row = self.image_list[index]
        url = str(row['coco_url'])
        _id = str(row['id'])
        img, img_s = self._fetch(_id, url)
        return _id, url, img, img_s

    def get(self, url):
        """Fetch an arbitrary image URL.

        Returns ``(url, url, pixel_values, size)``. This mirrors
        ``__getitem__``, with the URL itself used as the id.
        """
        img, img_s = self._fetch(url, url)
        return url, url, img, img_s

    def __len__(self):
        return len(self.image_list)

    def __add__(self, other):
        """Return a new dataset whose records are ``self``'s followed by ``other``'s.

        Neither operand nor its ``image_list`` is modified. The result is a
        shallow copy of ``self`` (same processor and settings) with a fresh
        concatenated record list.
        """
        import copy

        combined = copy.copy(self)
        combined.image_list = list(self.image_list)
        combined.image_list.extend(other.image_list)
        return combined
class TestImageSet(TestImageSetOnline):
    """COCO test images read from a local copy and preprocessed for CLIP.

    Uses the same records and processor as the online dataset, but each
    image is loaded from ``droot`` instead of being downloaded. The path
    component of ``coco_url`` (e.g. ``val2017/000000397133.jpg``) is joined
    onto ``droot``.
    """

    def __init__(self, droot, processor, image_list, timeout_base=0.5, timeout_mul=2):
        """
        Args:
            droot (str): local directory mirroring the URL layout, i.e.
                containing the folders named in each ``coco_url`` path.
            processor: CLIP-style processor, forwarded to the parent class.
            image_list: indexable sequence of records with ``'coco_url'``
                and ``'id'`` keys, forwarded to the parent class.
            timeout_base (float, optional): forwarded to the parent class;
                not used by local loading. Defaults to 0.5.
            timeout_mul (int, optional): forwarded to the parent class; not
                used by local loading. Defaults to 2.
        """
        super().__init__(processor, image_list, timeout_base, timeout_mul)
        self.droot = droot

    def __getitem__(self, index):
        """Load, convert and preprocess the local image for record ``index``.

        Returns:
            tuple: ``(id, url, pixel_values, size)``.
            - ``id`` is ``"<folder>_<record id>"``, where the folder is the
              URL's second-to-last path segment.
            - ``size`` is the ``(width, height)`` of the file before
              preprocessing.

        Errors from opening or decoding the file propagate to the caller;
        unlike the online dataset there is no retry.
        """
        from urllib.parse import urlsplit

        row = self.image_list[index]
        url = str(row['coco_url'])
        _id = '_'.join([url.split('/')[-2], str(row['id'])])
        # Drop scheme and host so both http:// and https:// COCO URLs map
        # onto the same relative path under droot.
        rel_path = urlsplit(url).path.lstrip('/')
        # The context manager closes the file handle when done. Preprocessing
        # runs inside it because PIL reads pixel data lazily.
        with Image.open(path.join(self.droot, rel_path)) as src:
            img_s = src.size
            # L is grayscale, CMYK uses alternative color channels.
            img = src.convert('RGB') if src.mode in ('L', 'CMYK', 'RGBA') else src
            ret = self.processor(text=None, images=img, return_tensors='pt')
        return _id, url, ret['pixel_values'][0], img_s