'''
LinCIR
Copyright (c) 2023-present NAVER Corp.
CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
'''
import os
import functools
import glob
import random
import json
from pathlib import Path
from typing import List, Optional, Union, Dict, Literal

import PIL
import PIL.Image
import torch
from torch.utils.data import Dataset
import webdataset as wds
import spacy
import numpy as np
import sng_parser
import datasets


def extract_keywords(spacy_nlp, caption):
    candidates = []

    nlp_caption = caption

    doc = spacy_nlp(nlp_caption)

    tmp = ''
    for word in doc:
        if word.pos_ == 'ADJ':
            if tmp == '':
                tmp += word.text
            else:
                tmp += ' ' + word.text
        elif word.pos_ == 'NOUN' or word.pos_ == 'PROPN':
            if tmp == '':
                tmp += word.text
            else:
                tmp += ' ' + word.text
        else:
            if tmp != '':
                candidates.append(tmp)
            tmp = ''
    if tmp != '':
        candidates.append(tmp)

    candidates = list(set(candidates))

    return candidates
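

# Illustrative usage sketch: how extract_keywords groups consecutive adjective/noun
# tokens into candidate keyword phrases. Assumes the spaCy model 'en_core_web_sm'
# is installed (python -m spacy download en_core_web_sm).
def _example_extract_keywords():
    nlp = spacy.load('en_core_web_sm')
    caption = 'a small black dog sitting on a wooden bench'
    # Expected candidates (order may vary because a set is used),
    # e.g. ['small black dog', 'wooden bench']
    print(extract_keywords(nlp, caption))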


def extract_keywords_spacy(spacy_nlp, caption):
    sequences = []
    current_sequence = []

    doc = spacy_nlp(caption)
    for token in doc:
        # Check if the token is a noun, proper noun, adjective, or determiner
        if token.pos_ in ['NOUN', 'PROPN', 'ADJ', 'DET']:
            current_sequence.append(token.text)
        else:
            # If we encounter a token that's not one of the desired POS and current_sequence is not empty
            if current_sequence:
                sequences.append(" ".join(current_sequence))
                current_sequence = []

    # Adding any remaining sequence after the loop
    if current_sequence:
        sequences.append(" ".join(current_sequence))

    return sequences
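

# Illustrative usage sketch: unlike extract_keywords above, extract_keywords_spacy
# also keeps determiners and preserves the order in which phrases appear, which is
# what lets the caller substitute the phrase back into the caption verbatim.
def _example_extract_keywords_spacy():
    nlp = spacy.load('en_core_web_sm')
    caption = 'a small black dog sitting on a wooden bench'
    # e.g. ['a small black dog', 'a wooden bench']
    print(extract_keywords_spacy(nlp, caption))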


def extract_sng(caption):
    graph = sng_parser.parse(caption)
    entities = [x['head'] for i, x in enumerate(graph['entities'])]
    relations = [{'subject': entities[x['subject']], 'object': entities[x['object']], 'relation': x['relation']} for x in graph['relations']]
    return entities, relations
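

# Illustrative usage sketch: extract_sng returns the entity heads and the
# subject/relation/object triples produced by the SceneGraphParser package
# (sng_parser). Exact outputs depend on the parser version.
def _example_extract_sng():
    entities, relations = extract_sng('a woman is riding a horse on the beach')
    # entities  -> e.g. ['woman', 'horse', 'beach']
    # relations -> e.g. [{'subject': 'woman', 'object': 'horse', 'relation': 'riding'}, ...]
    print(entities, relations)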


def clean_caption(caption, tokenizer):
    if caption is None:
        caption = ''
    if '<PERSON>' in caption:  # to handle GCC12M captions
        caption = caption.replace('<PERSON>', 'person')
    caption = caption.lower().replace('$', '').strip()
    tokens = tokenizer.encode(caption, padding='longest', return_tensors='pt')
    if tokens.shape[1] > 77:
        caption = tokenizer.batch_decode(tokens[:, 1:76])[0]
    return caption
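

# Illustrative usage sketch: clean_caption lowercases, strips '$', replaces the
# '<PERSON>' placeholder, and truncates captions longer than the 77-token CLIP
# context window. This sketch assumes a Hugging Face CLIPTokenizer is the
# tokenizer passed in.
def _example_clean_caption():
    from transformers import CLIPTokenizer
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
    print(clean_caption('A photo of a $100 <PERSON> statue', tokenizer))
    # -> 'a photo of a 100 person statue'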


def preprocess_precomputed_base(sample, spacy_nlp, keywords_list, tokenizer):
    '''
    'image_feature.npy', 'json'
    '''
    image_feature, image_feature_giga, meta = sample

    caption = clean_caption(meta['source_caption'], tokenizer)

    keywords = ['']
    try:
        keywords = extract_keywords_spacy(spacy_nlp, caption)
    except Exception as e:
        #print(e)
        pass

    # for keywords
    indicator = 1

    replaced_caption = caption
    for keyword in keywords:
        if keyword != '' and keyword in caption:
            replaced_caption = replaced_caption.replace(keyword, '[$]')
        else:
            tmp_keywords = caption.split(' ')
            if len(tmp_keywords) > 0:
                selected_keywords = random.sample(tmp_keywords, k=min(int(len(tmp_keywords) * 1.0), 1))
                for selected_keyword in selected_keywords:
                    replaced_caption = replaced_caption.replace(selected_keyword, '[$]')
            else:
                replaced_caption = f'a photo of [$] that {caption}'
            indicator = 0
            break

    token_dict = tokenizer(text=caption, return_tensors='pt', padding='max_length', truncation=True)
    tokens, attention_mask = token_dict['input_ids'][0], token_dict['attention_mask'][0]

    replaced_token_dict = tokenizer(text=replaced_caption, return_tensors='pt', padding='max_length', truncation=True)
    replaced_tokens, replaced_attention_mask = replaced_token_dict['input_ids'][0], replaced_token_dict['attention_mask'][0]

    replaced_tokens = torch.where(replaced_tokens == 49408,
                                  torch.ones_like(replaced_tokens) * 259,
                                  replaced_tokens)

    if 259 not in replaced_tokens:
        replaced_caption = 'a photo of [$]'
        replaced_token_dict = tokenizer(text=replaced_caption, return_tensors='pt', padding='max_length', truncation=True)
        replaced_tokens, replaced_attention_mask = replaced_token_dict['input_ids'][0], replaced_token_dict['attention_mask'][0]
        replaced_tokens = torch.where(replaced_tokens == 49408,
                                      torch.ones_like(replaced_tokens) * 259,
                                      replaced_tokens)
        indicator = 0

    new_sample = [tokens, replaced_tokens, indicator]

    return tuple(new_sample)
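

# Illustrative usage sketch: preprocess_precomputed_base turns a webdataset-style
# sample into (caption tokens, keyword-masked tokens, indicator). The sketch adds
# '[$]' as an extra special token, which a stock CLIPTokenizer assigns id 49408;
# the function then rewrites that id to 259, the pseudo-token slot used by LinCIR.
# The precomputed image features are not used by this function, so dummy arrays
# stand in for them here.
def _example_preprocess_precomputed_base():
    from transformers import CLIPTokenizer
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
    tokenizer.add_special_tokens({'additional_special_tokens': ['[$]']})  # gets id 49408
    nlp = spacy.load('en_core_web_sm')
    meta = {'source_caption': 'a red vintage car parked near the old station'}
    dummy_feat = np.zeros(1, dtype=np.float32)  # unused by this function
    tokens, replaced_tokens, indicator = preprocess_precomputed_base(
        (dummy_feat, dummy_feat, meta), nlp, [], tokenizer)
    print(tokens.shape, replaced_tokens.shape, indicator)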


class CaptionDataset(Dataset):
    def __init__(self, captions, tokenizer, spacy_nlp):
        self.captions = captions
        self.tokenizer = tokenizer
        self.spacy_nlp = spacy_nlp

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        caption = self.captions[idx]

        caption = clean_caption(caption, self.tokenizer)

        keywords = [""]
        try:
            keywords = extract_keywords_spacy(self.spacy_nlp, caption)
        except Exception as e:
            #print(e)
            pass

        # for keywords
        indicator = 1

        replaced_caption = caption

        if len(keywords) == 0:
            keywords = [""]

        for keyword in keywords:
            if keyword != '' and keyword in caption:
                replaced_caption = replaced_caption.replace(keyword, '[$]')
            else:
                tmp_keywords = caption.split(' ')
                if len(tmp_keywords) > 0:
                    selected_keywords = random.sample(tmp_keywords, k=min(int(len(tmp_keywords) * 1.0), 1))
                    for selected_keyword in selected_keywords:
                        replaced_caption = replaced_caption.replace(selected_keyword, '[$]')
                else:
                    replaced_caption = f'a photo of [$] that {caption}'
                indicator = 0
                break

        token_dict = self.tokenizer(text=caption, return_tensors='pt', padding='max_length', truncation=True)
        tokens, attention_mask = token_dict['input_ids'][0], token_dict['attention_mask'][0]

        replaced_token_dict = self.tokenizer(text=replaced_caption, return_tensors='pt', padding='max_length', truncation=True)
        replaced_tokens, replaced_attention_mask = replaced_token_dict['input_ids'][0], replaced_token_dict['attention_mask'][0]

        replaced_tokens = torch.where(replaced_tokens == 49408,
                                      torch.ones_like(replaced_tokens) * 259,
                                      replaced_tokens)

        if 259 not in replaced_tokens:
            replaced_caption = 'a photo of [$]'
            replaced_token_dict = self.tokenizer(text=replaced_caption, return_tensors='pt', padding='max_length', truncation=True)
            replaced_tokens, replaced_attention_mask = replaced_token_dict['input_ids'][0], replaced_token_dict['attention_mask'][0]
            replaced_tokens = torch.where(replaced_tokens == 49408,
                                          torch.ones_like(replaced_tokens) * 259,
                                          replaced_tokens)
            indicator = 0

        return tokens, replaced_tokens, indicator
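

# Illustrative usage sketch: building a CaptionDataset from an in-memory caption
# list. Assumes a Hugging Face CLIPTokenizer with '[$]' added, as in the sketch
# above.
def _example_caption_dataset():
    from transformers import CLIPTokenizer
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
    tokenizer.add_special_tokens({'additional_special_tokens': ['[$]']})
    dataset = CaptionDataset(['a man riding a blue bicycle'], tokenizer, spacy.load('en_core_web_sm'))
    tokens, replaced_tokens, indicator = dataset[0]
    print(tokens.shape, replaced_tokens.shape, indicator)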


def build_loader(args, tokenizer, accelerator):
    data_names = {'dataset1': 'dangne/gcc_caption_only',
                  'dataset2': 'FredZhang7/stable-diffusion-prompts-2.47M',
                  'dataset3': 'Geonmo/midjourney-prompts-only',
                  }

    for k, v in data_names.items():
        if not os.path.exists(os.path.join('./datasets', k)):
            if accelerator.is_main_process:
                print('Downloading captions is required')
                db = datasets.load_dataset(v, cache_dir=os.path.join('./datasets', k))

    captions = []
    for k, v in data_names.items():
        db = datasets.load_dataset(v, cache_dir=os.path.join('./datasets', k))
        captions += db['train']['text']

    dataset = CaptionDataset(captions, tokenizer, spacy.load('en_core_web_sm'))

    data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, drop_last=True, shuffle=True)

    return data_loader
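

# Illustrative usage sketch: calling build_loader with a minimal argparse.Namespace
# standing in for the training arguments and a default accelerate.Accelerator.
# Note that this downloads and caches the three caption datasets listed above on
# first use.
def _example_build_loader():
    import argparse
    from accelerate import Accelerator
    from transformers import CLIPTokenizer
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
    tokenizer.add_special_tokens({'additional_special_tokens': ['[$]']})
    args = argparse.Namespace(batch_size=32, num_workers=4)
    loader = build_loader(args, tokenizer, Accelerator())
    tokens, replaced_tokens, indicator = next(iter(loader))
    print(tokens.shape, replaced_tokens.shape, indicator.shape)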


class FashionIQDataset(Dataset):
    """
    Copy-paste from https://github.com/miccunifi/SEARLE/blob/main/src/datasets.py
    FashionIQ dataset class for PyTorch.
    The dataset can be used in 'relative' or 'classic' mode:
        - In 'classic' mode the dataset yields a dict with keys ['image', 'image_name']
        - In 'relative' mode the dataset yields a dict with keys:
            - ['reference_image', 'reference_name', 'target_image', 'target_name', 'relative_captions'] when
             split in ['train', 'val']
            - ['reference_image', 'reference_name', 'relative_captions'] when split == test
    """

    def __init__(self, dataset_path: Union[Path, str], split: Literal['train', 'val', 'test'], dress_types: List[str],
                 mode: Literal['relative', 'classic'], preprocess: callable, no_duplicates: Optional[bool] = False):
        """
        :param dataset_path: path to the FashionIQ dataset
        :param split: dataset split, should be in ['train', 'val', 'test']
        :param dress_types: list of fashionIQ categories, each category should be in ['dress', 'shirt', 'toptee']
        :param mode: dataset mode, should be in ['relative', 'classic']:
            - In 'classic' mode the dataset yields a dict with keys ['image', 'image_name']
            - In 'relative' mode the dataset yields a dict with keys:
                - ['reference_image', 'reference_name', 'target_image', 'target_name', 'relative_captions']
                 when split in ['train', 'val']
                - ['reference_image', 'reference_name', 'relative_captions'] when split == test
        :param preprocess: function which preprocesses the image
        :param no_duplicates: if True, the dataset will not yield duplicate images in relative mode, does not affect classic mode
        """
        dataset_path = Path(dataset_path)
        self.dataset_path = dataset_path
        self.mode = mode
        self.dress_types = dress_types
        self.split = split
        self.no_duplicates = no_duplicates

        # Validate the inputs
        if mode not in ['relative', 'classic']:
            raise ValueError("mode should be in ['relative', 'classic']")
        if split not in ['test', 'train', 'val']:
            raise ValueError("split should be in ['test', 'train', 'val']")
        for dress_type in dress_types:
            if dress_type not in ['dress', 'shirt', 'toptee']:
                raise ValueError("dress_type should be in ['dress', 'shirt', 'toptee']")

        self.preprocess = preprocess

        # get triplets made by (reference_image, target_image, a pair of relative captions)
        self.triplets: List[dict] = []
        for dress_type in dress_types:
            with open(dataset_path / 'captions' / f'cap.{dress_type}.{split}.json') as f:
                self.triplets.extend(json.load(f))

        # Remove duplicates from triplets
        if self.no_duplicates:
            seen = set()
            new_triplets = []
            for triplet in self.triplets:
                if triplet['candidate'] not in seen:
                    seen.add(triplet['candidate'])
                    new_triplets.append(triplet)
            self.triplets = new_triplets

        # get the image names
        self.image_names: list = []
        for dress_type in dress_types:
            with open(dataset_path / 'image_splits' / f'split.{dress_type}.{split}.json') as f:
                self.image_names.extend(json.load(f))

        print(f"FashionIQ {split} - {dress_types} dataset in {mode} mode initialized")

    def __getitem__(self, index) -> dict:
        try:
            if self.mode == 'relative':
                relative_captions = self.triplets[index]['captions']
                reference_name = self.triplets[index]['candidate']

                if self.split in ['train', 'val']:
                    reference_image_path = self.dataset_path / 'images' / f"{reference_name}.jpg"
                    reference_image = self.preprocess(PIL.Image.open(reference_image_path), return_tensors='pt')['pixel_values'][0]
                    target_name = self.triplets[index]['target']
                    target_image_path = self.dataset_path / 'images' / f"{target_name}.jpg"
                    target_image = self.preprocess(PIL.Image.open(target_image_path), return_tensors='pt')['pixel_values'][0]

                    return {
                        'reference_image': reference_image,
                        'reference_name': reference_name,
                        'target_image': target_image,
                        'target_name': target_name,
                        'relative_captions': relative_captions
                    }

                elif self.split == 'test':
                    reference_image_path = self.dataset_path / 'images' / f"{reference_name}.jpg"
                    reference_image = self.preprocess(PIL.Image.open(reference_image_path), return_tensors='pt')['pixel_values'][0]

                    return {
                        'reference_image': reference_image,
                        'reference_name': reference_name,
                        'relative_captions': relative_captions
                    }

            elif self.mode == 'classic':
                image_name = self.image_names[index]
                image_path = self.dataset_path / 'images' / f"{image_name}.jpg"
                image = self.preprocess(PIL.Image.open(image_path), return_tensors='pt')['pixel_values'][0]

                return {
                    'image': image,
                    'image_name': image_name
                }

            else:
                raise ValueError("mode should be in ['relative', 'classic']")

        except Exception as e:
            print(f"Exception: {e}")

    def __len__(self):
        if self.mode == 'relative':
            return len(self.triplets)
        elif self.mode == 'classic':
            return len(self.image_names)
        else:
            raise ValueError("mode should be in ['relative', 'classic']")
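

# Illustrative usage sketch: loading FashionIQ validation triplets. The preprocess
# callable is assumed to be a Hugging Face CLIPImageProcessor, which matches the
# `preprocess(image, return_tensors='pt')['pixel_values']` call pattern used above.
# './datasets/fashioniq' is a placeholder path.
def _example_fashioniq():
    from transformers import CLIPImageProcessor
    preprocess = CLIPImageProcessor.from_pretrained('openai/clip-vit-large-patch14')
    dataset = FashionIQDataset('./datasets/fashioniq', 'val', ['dress', 'shirt', 'toptee'],
                               'relative', preprocess)
    sample = dataset[0]
    print(sample['reference_name'], sample['relative_captions'])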


class CIRRDataset(Dataset):
    """
    Copy-paste from https://github.com/miccunifi/SEARLE/blob/main/src/datasets.py
    CIRR dataset class for PyTorch dataloader.
    The dataset can be used in 'relative' or 'classic' mode:
        - In 'classic' mode the dataset yields a dict with keys ['image', 'image_name']
        - In 'relative' mode the dataset yields a dict with keys:
            - ['reference_image', 'reference_name', 'target_image', 'target_name', 'relative_caption', 'group_members']
             when split in ['train', 'val']
            - ['reference_image', 'reference_name', 'relative_caption', 'group_members', 'pair_id'] when split == test
    """

    def __init__(self, dataset_path: Union[Path, str], split: Literal['train', 'val', 'test'],
                 mode: Literal['relative', 'classic'], preprocess: callable, no_duplicates: Optional[bool] = False):
        """
        :param dataset_path: path to the CIRR dataset
        :param split: dataset split, should be in ['train', 'val', 'test']
        :param mode: dataset mode, should be in ['relative', 'classic']:
            - In 'classic' mode the dataset yields a dict with keys ['image', 'image_name']
            - In 'relative' mode the dataset yields a dict with keys:
                - ['reference_image', 'reference_name', 'target_image', 'target_name', 'relative_caption',
                 'group_members'] when split in ['train', 'val']
                - ['reference_image', 'reference_name', 'relative_caption', 'group_members', 'pair_id'] when split == test
        :param preprocess: function which preprocesses the image
        :param no_duplicates: if True, the dataset will not yield duplicate images in relative mode, does not affect classic mode
        """
        dataset_path = Path(dataset_path)
        self.dataset_path = dataset_path
        self.preprocess = preprocess
        self.mode = mode
        self.split = split
        self.no_duplicates = no_duplicates

        if split == "test":
            split = "test1"
            self.split = "test1"

        # Validate inputs
        if split not in ['test1', 'train', 'val']:
            raise ValueError("split should be in ['test1', 'train', 'val']")
        if mode not in ['relative', 'classic']:
            raise ValueError("mode should be in ['relative', 'classic']")

        # get triplets made by (reference_image, target_image, relative caption)
        with open(dataset_path / 'cirr' / 'captions' / f'cap.rc2.{split}.json') as f:
            self.triplets = json.load(f)

        # Remove duplicates from triplets
        if self.no_duplicates:
            seen = set()
            new_triplets = []
            for triplet in self.triplets:
                if triplet['reference'] not in seen:
                    seen.add(triplet['reference'])
                    new_triplets.append(triplet)
            self.triplets = new_triplets

        # get a mapping from image name to relative path
        with open(dataset_path / 'cirr' / 'image_splits' / f'split.rc2.{split}.json') as f:
            self.name_to_relpath = json.load(f)

        print(f"CIRR {split} dataset in {mode} mode initialized")

    def __getitem__(self, index) -> dict:
        try:
            if self.mode == 'relative':
                group_members = self.triplets[index]['img_set']['members']
                reference_name = self.triplets[index]['reference']
                relative_caption = self.triplets[index]['caption']

                if self.split in ['train', 'val']:
                    reference_image_path = self.dataset_path / self.name_to_relpath[reference_name]
                    reference_image = self.preprocess(PIL.Image.open(reference_image_path), return_tensors='pt')['pixel_values'][0]
                    target_hard_name = self.triplets[index]['target_hard']
                    target_image_path = self.dataset_path / self.name_to_relpath[target_hard_name]
                    target_image = self.preprocess(PIL.Image.open(target_image_path), return_tensors='pt')['pixel_values'][0]

                    return {
                        'reference_image': reference_image,
                        'reference_name': reference_name,
                        'target_image': target_image,
                        'target_name': target_hard_name,
                        'relative_caption': relative_caption,
                        'group_members': group_members
                    }

                elif self.split == 'test1':
                    pair_id = self.triplets[index]['pairid']
                    reference_image_path = self.dataset_path / self.name_to_relpath[reference_name]
                    reference_image = self.preprocess(PIL.Image.open(reference_image_path), return_tensors='pt')['pixel_values'][0]

                    return {
                        'reference_image': reference_image,
                        'reference_name': reference_name,
                        'relative_caption': relative_caption,
                        'group_members': group_members,
                        'pair_id': pair_id
                    }

            elif self.mode == 'classic':
                image_name = list(self.name_to_relpath.keys())[index]
                image_path = self.dataset_path / self.name_to_relpath[image_name]
                im = PIL.Image.open(image_path)
                image = self.preprocess(im, return_tensors='pt')['pixel_values'][0]

                return {
                    'image': image,
                    'image_name': image_name
                }

            else:
                raise ValueError("mode should be in ['relative', 'classic']")

        except Exception as e:
            print(f"Exception: {e}")

    def __len__(self):
        if self.mode == 'relative':
            return len(self.triplets)
        elif self.mode == 'classic':
            return len(self.name_to_relpath)
        else:
            raise ValueError("mode should be in ['relative', 'classic']")
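

# Illustrative usage sketch: iterating CIRR in 'classic' mode to index the candidate
# gallery, again assuming a CLIPImageProcessor as the preprocess callable.
# './datasets/cirr' is a placeholder path.
def _example_cirr_gallery():
    from transformers import CLIPImageProcessor
    preprocess = CLIPImageProcessor.from_pretrained('openai/clip-vit-large-patch14')
    dataset = CIRRDataset('./datasets/cirr', 'val', 'classic', preprocess)
    loader = torch.utils.data.DataLoader(dataset, batch_size=16, num_workers=2)
    for batch in loader:
        print(batch['image'].shape, len(batch['image_name']))
        break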


class CIRCODataset(Dataset):
    """
    Copy-paste from https://github.com/miccunifi/SEARLE/blob/main/src/datasets.py
    CIRCO dataset class for PyTorch.
    The dataset can be used in 'relative' or 'classic' mode:
        - In 'classic' mode the dataset yields a dict with keys ['image', 'image_name']
        - In 'relative' mode the dataset yields a dict with keys:
            - ['reference_image', 'reference_name', 'target_image', 'target_name', 'relative_caption', 'shared_concept',
             'gt_img_ids', 'query_id'] when split == 'val'
            - ['reference_image', 'reference_name', 'relative_caption', 'shared_concept', 'query_id'] when split == test
    """

    def __init__(self, dataset_path: Union[str, Path], split: Literal['val', 'test'],
                 mode: Literal['relative', 'classic'], preprocess: callable):
        """
        Args:
            dataset_path (Union[str, Path]): path to CIRCO dataset
            split (str): dataset split, should be in ['test', 'val']
            mode (str): dataset mode, should be in ['relative', 'classic']
            preprocess (callable): function which preprocesses the image
        """

        # Set dataset paths and configurations
        dataset_path = Path(dataset_path)
        self.mode = mode
        self.split = split
        self.preprocess = preprocess
        self.data_path = dataset_path

        # Ensure input arguments are valid
        if mode not in ['relative', 'classic']:
            raise ValueError("mode should be in ['relative', 'classic']")
        if split not in ['test', 'val']:
            raise ValueError("split should be in ['test', 'val']")

        # Load COCO images information
        with open(dataset_path / 'COCO2017_unlabeled' / "annotations" / "image_info_unlabeled2017.json", "r") as f:
            imgs_info = json.load(f)

        self.img_paths = [dataset_path / 'COCO2017_unlabeled' / "unlabeled2017" / img_info["file_name"] for img_info in
                          imgs_info["images"]]
        self.img_ids = [img_info["id"] for img_info in imgs_info["images"]]
        self.img_ids_indexes_map = {str(img_id): i for i, img_id in enumerate(self.img_ids)}

        # get CIRCO annotations
        with open(dataset_path / 'annotations' / f'{split}.json', "r") as f:
            self.annotations: List[dict] = json.load(f)

        # Get maximum number of ground truth images (for padding when loading the images)
        self.max_num_gts = 23  # Maximum number of ground truth images

        print(f"CIRCODataset {split} dataset in {mode} mode initialized")

    def get_target_img_ids(self, index) -> Dict[str, int]:
        """
        Returns the id of the target image and ground truth images for a given query

        Args:
            index (int): id of the query

        Returns:
            Dict[str, int]: dictionary containing target image id and a list of ground truth image ids
        """

        return {
            'target_img_id': self.annotations[index]['target_img_id'],
            'gt_img_ids': self.annotations[index]['gt_img_ids']
        }

    def __getitem__(self, index) -> dict:
        """
        Returns a specific item from the dataset based on the index.

        In 'classic' mode, the dataset yields a dictionary with the following keys: [img, img_id]
        In 'relative' mode, the dataset yields dictionaries with the following keys:
            - [reference_img, reference_img_id, target_img, target_img_id, relative_caption, shared_concept, gt_img_ids,
            query_id] if split == val
            - [reference_img, reference_img_id, relative_caption, shared_concept, query_id] if split == test
        """

        if self.mode == 'relative':
            # Get the query id
            query_id = str(self.annotations[index]['id'])

            # Get relative caption and shared concept
            relative_caption = self.annotations[index]['relative_caption']
            shared_concept = self.annotations[index]['shared_concept']

            # Get the reference image
            reference_img_id = str(self.annotations[index]['reference_img_id'])
            reference_img_path = self.img_paths[self.img_ids_indexes_map[reference_img_id]]
            reference_img = self.preprocess(PIL.Image.open(reference_img_path), return_tensors='pt')['pixel_values'][0]

            if self.split == 'val':
                # Get the target image and ground truth images
                target_img_id = str(self.annotations[index]['target_img_id'])
                gt_img_ids = [str(x) for x in self.annotations[index]['gt_img_ids']]
                target_img_path = self.img_paths[self.img_ids_indexes_map[target_img_id]]
                target_img = self.preprocess(PIL.Image.open(target_img_path), return_tensors='pt')['pixel_values'][0]

                # Pad ground truth image IDs with empty strings for collate_fn
                gt_img_ids += [''] * (self.max_num_gts - len(gt_img_ids))

                return {
                    'reference_image': reference_img,
                    'reference_name': reference_img_id,
                    'target_image': target_img,
                    'target_name': target_img_id,
                    'relative_caption': relative_caption,
                    'shared_concept': shared_concept,
                    'gt_img_ids': gt_img_ids,
                    'query_id': query_id,
                }

            elif self.split == 'test':
                return {
                    'reference_image': reference_img,
                    'reference_name': reference_img_id,
                    'relative_caption': relative_caption,
                    'shared_concept': shared_concept,
                    'query_id': query_id,
                }

        elif self.mode == 'classic':
            # Get image ID and image path
            img_id = str(self.img_ids[index])
            img_path = self.img_paths[index]

            # Preprocess image and return
            img = self.preprocess(PIL.Image.open(img_path), return_tensors='pt')['pixel_values'][0]

            return {
                'image': img,
                'image_name': img_id
            }

    def __len__(self):
        """
        Returns the length of the dataset.
        """
        if self.mode == 'relative':
            return len(self.annotations)
        elif self.mode == 'classic':
            return len(self.img_ids)
        else:
            raise ValueError("mode should be in ['relative', 'classic']")
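

# Illustrative usage sketch: reading CIRCO validation queries together with their
# padded ground-truth id lists, assuming a CLIPImageProcessor as the preprocess
# callable. './datasets/circo' is a placeholder path.
def _example_circo():
    from transformers import CLIPImageProcessor
    preprocess = CLIPImageProcessor.from_pretrained('openai/clip-vit-large-patch14')
    dataset = CIRCODataset('./datasets/circo', 'val', 'relative', preprocess)
    sample = dataset[0]
    # gt_img_ids is padded with '' up to max_num_gts (23) so the default collate_fn works
    print(sample['query_id'], sample['relative_caption'], len(sample['gt_img_ids']))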