import time
from io import BytesIO
from os import path

import requests
from PIL import Image
from torch.utils.data import Dataset


class TestImageSetOnline(Dataset):
    """Image dataset that downloads each image over HTTP and preprocesses it.

    Every item is fetched with ``requests``, decoded with PIL and passed
    through ``processor`` (a Hugging Face CLIP-style preprocessing callable).
    A failed download or decode is retried without an upper bound; between
    attempts the loader sleeps for ``self.timeout`` seconds, a back-off value
    that grows by ``timeout_mul`` after each failure and shrinks by the same
    factor after each success (never below ``timeout_base`` when the
    multiplier is applied symmetrically).
    """

    # PIL modes that are converted to RGB before preprocessing
    # ('L' is grayscale, 'CMYK' uses alternative colour channels).
    _CONVERT_MODES = ('L', 'CMYK', 'RGBA')

    def __init__(self, processor, image_list, timeout_base=0.5, timeout_mul=2,
                 request_timeout=30.0):
        """
        Args:
            processor (CLIP preprocessor): callable invoked as
                ``processor(text=None, images=img, ...)``; its result must
                support ``result['pixel_values'][0]``.
            image_list: image metadata, indexed by integer position; each
                record must expose ``'coco_url'`` and ``'id'`` keys.
                NOTE(review): earlier documentation described this as a
                pandas.DataFrame, but the code indexes it like a sequence of
                records - confirm against callers.
            timeout_base (float, optional): initial back-off sleep in seconds
                after a failed request. Defaults to 0.5.
            timeout_mul (int, optional): multiplier applied to the back-off
                every time a request fails. Defaults to 2.
            request_timeout (float, optional): per-request network timeout in
                seconds handed to ``requests.get``; ``None`` waits
                indefinitely. Defaults to 30.0.
        """
        self.image_list = image_list
        self.processor = processor
        self.timeout_base = timeout_base
        self.timeout = self.timeout_base
        self.timeout_mul = timeout_mul
        self.request_timeout = request_timeout

    def _preprocess(self, img):
        """Run ``processor`` on a PIL image.

        Returns:
            tuple: ``(pixel_values, size)`` where ``size`` is ``img.size``
            read before any mode conversion and ``pixel_values`` is the first
            entry of the processor's ``'pixel_values'`` output.
        """
        img_s = img.size
        if img.mode in self._CONVERT_MODES:
            img = img.convert('RGB')
        # NOTE(review): the keyword is spelled `return_tensor`; the Hugging
        # Face processors document `return_tensors`. Kept as-is because
        # correcting it may change the type of the returned pixel values -
        # confirm the intended behaviour before changing.
        ret = self.processor(text=None, images=img, return_tensor='pt')
        return ret['pixel_values'][0], img_s

    def _fetch(self, _id, url):
        """Download and preprocess ``url``, retrying until it succeeds.

        Returns:
            tuple: ``(_id, url, pixel_values, original_size)``.
        """
        while True:
            try:
                # An explicit timeout keeps a stalled connection from blocking
                # the caller forever; a timeout is retried like any failure.
                response = requests.get(url, timeout=self.request_timeout)
                # Surface HTTP errors directly instead of letting PIL fail
                # later while trying to decode an error page.
                response.raise_for_status()
                img, img_s = self._preprocess(
                    Image.open(BytesIO(response.content)))
            except Exception as e:
                # KeyboardInterrupt/SystemExit are not Exception subclasses,
                # so they propagate without any extra check here.
                print(f"{_id} {url}: {str(e)}")
                time.sleep(self.timeout)
                # Back off harder before issuing the next request.
                self.timeout *= self.timeout_mul
            else:
                # Relax the back-off again after a successful request.
                if self.timeout > self.timeout_base:
                    self.timeout /= self.timeout_mul
                return _id, url, img, img_s

    def __getitem__(self, index):
        """Fetch the image described by ``image_list[index]``.

        Returns:
            tuple: ``(id, url, pixel_values, original_size)`` with ``id`` and
            ``url`` as strings taken from the record's ``'id'`` and
            ``'coco_url'`` entries.
        """
        row = self.image_list[index]
        url = str(row['coco_url'])
        _id = str(row['id'])
        return self._fetch(_id, url)

    def get(self, url):
        """Fetch an arbitrary ``url``; the URL itself is used as the id.

        Returns:
            tuple: ``(url, url, pixel_values, original_size)``.
        """
        return self._fetch(url, url)

    def __len__(self,):
        """Return the number of records in ``image_list``."""
        return len(self.image_list)

    def __add__(self, other):
        """Append ``other``'s records to this dataset and return ``self``.

        NOTE(review): this mutates ``self.image_list`` in place via ``+=``
        rather than building a new dataset; kept for compatibility with
        existing callers.
        """
        self.image_list += other.image_list
        return self


class TestImageSet(TestImageSetOnline):
    """Variant of ``TestImageSetOnline`` that reads images from local disk.

    Files are looked up under ``droot`` using the part of each record's
    ``'coco_url'`` that follows ``http://images.cocodataset.org/``.
    """

    def __init__(self, droot, processor, image_list, timeout_base=0.5, timeout_mul=2):
        """
        Args:
            droot (str): root directory that holds the local image files.
            processor (CLIP preprocessor): see ``TestImageSetOnline``.
            image_list: image metadata records, see ``TestImageSetOnline``.
            timeout_base (float, optional): forwarded to the parent class.
                Defaults to 0.5.
            timeout_mul (int, optional): forwarded to the parent class.
                Defaults to 2.
        """
        super().__init__(processor, image_list, timeout_base, timeout_mul)
        self.droot = droot

    def __getitem__(self, index):
        """Load the image for ``image_list[index]`` from disk and preprocess it.

        Returns:
            tuple: ``(id, url, pixel_values, original_size)`` where ``id`` is
            the second-to-last URL path component joined with the record's
            ``'id'`` by an underscore.

        Raises:
            IndexError: if the URL does not contain the
                ``http://images.cocodataset.org/`` prefix.
        """
        row = self.image_list[index]
        url = str(row['coco_url'])
        _id = '_'.join([url.split('/')[-2], str(row['id'])])
        # Read the image from the local mirror rather than the network.
        local_path = path.join(
            self.droot, url.split('http://images.cocodataset.org/')[1])
        # The context manager releases the file handle even if preprocessing
        # raises.
        with Image.open(local_path) as img:
            img, img_s = self._preprocess(img)
        return _id, url, img, img_s