Spaces:
Runtime error
Runtime error
File size: 4,750 Bytes
3f1124e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import copy
import time
from io import BytesIO
from os import path
from urllib.parse import urlparse

import requests
from PIL import Image
from torch.utils.data import Dataset

class TestImageSetOnline(Dataset):
    """Test image set that downloads images and runs a Hugging Face CLIP processor.

    Every item is fetched over HTTP from its ``coco_url``. A failed attempt is
    retried after sleeping ``self.timeout`` seconds; the sleep grows by
    ``timeout_mul`` after each failure and shrinks back after each success.

    Args:
        Dataset (torch.utils.data.Dataset): base class.
    """

    def __init__(self, processor, image_list, timeout_base=0.5, timeout_mul=2,
                 max_attempts=None, request_timeout=None):
        """
        Args:
            processor (CLIP processor): callable accepting ``images=`` and
                ``return_tensors=`` and returning a mapping that contains
                ``'pixel_values'``; converts images to a CLIP-digestible format.
            image_list (sequence of mappings): per-image records, each providing
                at least ``'coco_url'`` and ``'id'``. Presumably the ``images``
                entries of a COCO annotation file -- confirm against callers.
            timeout_base (float, optional): initial back-off sleep, in seconds,
                between failed attempts. Defaults to 0.5.
            timeout_mul (int, optional): factor the back-off is multiplied by
                after every failed request (and divided by after a success
                while above ``timeout_base``). Defaults to 2.
            max_attempts (int or None, optional): maximum number of download
                attempts per image before ``RuntimeError`` is raised; at least
                one attempt is always made. ``None`` (default) retries forever,
                which matches the historical behaviour.
            request_timeout (float or None, optional): forwarded as
                ``requests.get(..., timeout=request_timeout)``. ``None``
                (default) lets a stalled connection block indefinitely.
        """
        self.image_list = image_list
        self.processor = processor
        self.timeout_base = timeout_base
        self.timeout = self.timeout_base
        self.timeout_mul = timeout_mul
        self.max_attempts = max_attempts
        self.request_timeout = request_timeout

    def __getitem__(self, index):
        """Download and preprocess the image at ``index``.

        Returns:
            tuple: ``(id, url, pixel_values, (width, height))`` where ``id`` is
            ``str(row['id'])`` and the size is that of the downloaded image
            before preprocessing.
        """
        row = self.image_list[index]
        url = str(row['coco_url'])
        _id = str(row['id'])
        img, img_s = self._fetch(_id, url)
        return _id, url, img, img_s

    def get(self, url):
        """Download and preprocess an arbitrary image URL.

        The URL doubles as the id, so the result is
        ``(url, url, pixel_values, (width, height))``.
        """
        img, img_s = self._fetch(url, url)
        return url, url, img, img_s

    def _preprocess(self, img):
        """Run the processor on a PIL image; return ``(pixel_values, original_size)``."""
        img_s = img.size
        # The model expects 3-channel input: normalise every other mode
        # (L, LA, P, CMYK, RGBA, I, ...) instead of a hand-picked subset.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        ret = self.processor(images=img, return_tensors='pt')
        # Processor output is batched; take the single image's tensor.
        return ret['pixel_values'][0], img_s

    def _fetch(self, _id, url):
        """Download ``url`` and preprocess it, retrying with exponential back-off.

        ``_id`` only labels the error lines printed on failure.

        Returns:
            tuple: ``(pixel_values, (width, height))``.

        Raises:
            RuntimeError: after ``self.max_attempts`` failed attempts within
                this call (never raised when ``max_attempts`` is ``None``).
        """
        attempts = 0
        while True:
            try:
                response = requests.get(url, timeout=self.request_timeout)
                # Report HTTP errors (404, 5xx) as such rather than as an
                # opaque "cannot identify image file" from PIL.
                response.raise_for_status()
                with Image.open(BytesIO(response.content)) as raw:
                    result = self._preprocess(raw)
            except Exception as e:
                # KeyboardInterrupt/SystemExit derive from BaseException, so
                # they are not caught here and still abort the loop.
                print(f"{_id} {url}: {str(e)}")
                attempts += 1
                if self.max_attempts is not None and attempts >= self.max_attempts:
                    raise RuntimeError(
                        f"{_id} {url}: giving up after {attempts} attempts"
                    ) from e
                time.sleep(self.timeout)
                # Back off harder before issuing the next request.
                self.timeout *= self.timeout_mul
            else:
                # Success: relax the back-off towards its base value.
                if self.timeout > self.timeout_base:
                    self.timeout /= self.timeout_mul
                return result

    def __len__(self):
        """Return the number of records in ``image_list``."""
        return len(self.image_list)

    def __add__(self, other):
        """Return a new dataset with ``self``'s records followed by ``other``'s.

        The result is a shallow copy of ``self`` (same class, processor and
        settings) holding a fresh list, so neither operand -- nor the list
        either one was constructed from -- is modified. ``a += b`` still
        works because Python rebinds ``a`` to the returned object.
        """
        merged = copy.copy(self)
        merged.image_list = list(self.image_list) + list(other.image_list)
        return merged
class TestImageSet(TestImageSetOnline):
    """Offline variant of ``TestImageSetOnline`` that reads images from disk.

    Each image is loaded from ``droot`` joined with the path component of its
    ``coco_url`` (e.g. ``http://images.cocodataset.org/val2017/x.jpg`` ->
    ``<droot>/val2017/x.jpg``), so ``droot`` must mirror the URL layout.
    """

    def __init__(self, droot, processor, image_list, timeout_base=0.5, timeout_mul=2):
        """
        Args:
            droot (str): root directory of the local image mirror.
            processor, image_list, timeout_base, timeout_mul: forwarded
                unchanged to ``TestImageSetOnline.__init__``. The timeout
                settings only matter for the inherited online ``get`` method.
        """
        super().__init__(processor, image_list, timeout_base, timeout_mul)
        self.droot = droot

    def __getitem__(self, index):
        """Load and preprocess the image at ``index`` from ``droot``.

        Returns:
            tuple: ``(id, url, pixel_values, (width, height))``. ``id`` is the
            URL's parent directory joined to ``row['id']`` with ``'_'``
            (e.g. ``'val2017_397133'``); the size is that of the on-disk image
            before preprocessing.
        """
        row = self.image_list[index]
        url = str(row['coco_url'])
        _id = '_'.join([url.split('/')[-2], str(row['id'])])
        # Mirror the URL path under droot. Scheme and host are ignored, so
        # http:// and https:// COCO URLs resolve to the same local file.
        rel_path = urlparse(url).path.lstrip('/')
        # The context manager closes the file handle as soon as we are done.
        with Image.open(path.join(self.droot, rel_path)) as raw:
            img_s = raw.size
            # The model expects 3-channel input: normalise any non-RGB mode.
            img = raw if raw.mode == 'RGB' else raw.convert('RGB')
            ret = self.processor(images=img, return_tensors='pt')
        # Processor output is batched; take the single image's tensor.
        return _id, url, ret['pixel_values'][0], img_s
|