import json

import numpy as np
import torch
from PIL import Image
from torchvision import transforms


def read_json(file_name, suppress_console_info=False):
    """Load a JSON file and optionally report where it was read from."""
    with open(file_name, 'r') as f:
        data = json.load(f)
    if not suppress_console_info:
        print("Read from:", file_name)
    return data


def get_file_names(data, imgs_folder, feature_folder, suppress_console_info=False):
    """Build per-image lookup dicts, keyed by image_id, for image paths,
    feature paths, captions, labels, and coordinates from the parsed JSON."""
    image_file_names = {}
    feature_paths = {}
    captions = {}
    labels = {}
    lats = {}
    lons = {}
    for img in data['images']:
        image_name = img["image_name"]
        sample_id = img["sample_id"]
        image_id = f'{sample_id}_{image_name}'
        path_data = imgs_folder + f'{sample_id}/{image_name}'
        feature_data = feature_folder + f'{sample_id}/{image_name}.npy'
        image_file_names[image_id] = path_data
        feature_paths[image_id] = feature_data
        captions[image_id] = img["description"]
        labels[image_id] = img["labels"]
        lats[image_id] = img["lat"]
        lons[image_id] = img["lon"]
    return image_file_names, feature_paths, captions, labels, lats, lons


def get_data(image_file_names, captions, feature_paths, labels, lats, lons, image_id):
    """Look up all stored fields for a single image_id."""
    image_file_name = image_file_names[image_id]
    feature_path = feature_paths[image_id]
    caption = captions[image_id]
    label = labels[image_id]
    lat = lats[image_id]
    lon = lons[image_id]
    return image_file_name, feature_path, caption, label, lat, lon


def read_by_image_id(data_dir, imgs_folder, feature_folder, image_id=None):
    '''
    return:
        img           -> raw PIL image
        img_          -> transform(img) as a float32 tensor
        caption
        image_feature -> pre-extracted feature loaded from the .npy file
        label         -> list of numeric label ids
        label_en      -> text names of the labels
        lat
        lon
    '''
    data_info = read_json(data_dir)
    image_file_names, image_feature_paths, captions, labels, lats, lons = get_file_names(
        data_info, imgs_folder, feature_folder)
    image_file_name, image_feature_path, caption, label, lat, lon = get_data(
        image_file_names, captions, image_feature_paths, labels, lats, lons, image_id)

    # Map numeric label ids back to their text names using the label table in the JSON.
    label_en = []
    label131 = data_info['labels']
    for label_name, label_id in label131.items():
        if label_id in label:
            label_en.append(label_name)

    # Pre-extracted image feature saved as a NumPy array.
    image_feature = np.load(image_feature_path)

    # Load the image and apply CLIP-style preprocessing (224x224 crop, CLIP normalization).
    img = Image.open(image_file_name).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711))
    ])
    img_ = transform(img).float()

    return img, img_, caption, image_feature, label, label_en, lat, lon


# test
data_dir = '/data02/xy/dataEngine/json_data/merged_output_combined_9w_resplit.json'
imgs_folder = '/data02/xy/Clip-hash//datasets/image/'
feature_folder = '/data02/xy/Clip-hash/image_feature/georsclip_21_r0.9_fpn/'
image_id = 'sample44_889.jpg'
# img, img_, caption, image_feature, label, label_en, lat, lon = read_by_image_id(data_dir, imgs_folder, feature_folder, image_id)
# print(img, img_, caption, image_feature, label, label_en, lat, lon)