# LuojiaHOG / get_data_by_image_id.py
# Source: Hugging Face file viewer (uploader: aleo1; commit "Upload 41 files",
# bb6012a, verified; no virus detected; 3.74 kB).
import json
import os

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
def read_json(file_name, suppress_console_info=False):
    """Load and return the parsed contents of a JSON file.

    Args:
        file_name: Path of the JSON file to read.
        suppress_console_info: When true, skip the "Read from:" console message.

    Returns:
        The object produced by ``json.load`` for the file's contents.
    """
    # Decode explicitly as UTF-8 so non-ASCII text is read the same way
    # regardless of the platform's default encoding.
    with open(file_name, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if not suppress_console_info:
        print("Read from:", file_name)
    return data
def get_file_names(data, imgs_folder, feature_folder, suppress_console_info=False):
    """Index every image record in ``data`` by its image id.

    The image id is ``"<sample_id>_<image_name>"``. If two records produce the
    same id, the later record overwrites the earlier one.

    Args:
        data: Parsed dataset dict; ``data['images']`` is iterated and each
            record must provide ``image_name``, ``sample_id``, ``description``,
            ``labels``, ``lat`` and ``lon``.
        imgs_folder: Root folder of the images; files are resolved as
            ``<imgs_folder>/<sample_id>/<image_name>``.
        feature_folder: Root folder of the features; files are resolved as
            ``<feature_folder>/<sample_id>/<image_name>.npy``.
        suppress_console_info: Unused; kept so existing callers keep working.

    Returns:
        Six dicts keyed by image id, in this order: image paths, feature
        paths, captions, labels, latitudes, longitudes.
    """
    image_file_names = {}
    feature_pathes = {}
    captions = {}
    labels = {}
    lats = {}
    lons = {}
    for img in data['images']:
        image_name = img["image_name"]
        sample_id = img["sample_id"]
        image_id = f'{sample_id}_{image_name}'
        relative_path = f'{sample_id}/{image_name}'
        # os.path.join yields the same result as plain concatenation when the
        # folder ends with a separator, and inserts one when it does not.
        image_file_names[image_id] = os.path.join(imgs_folder, relative_path)
        feature_pathes[image_id] = os.path.join(feature_folder, relative_path + '.npy')
        captions[image_id] = img["description"]
        labels[image_id] = img["labels"]
        lats[image_id] = img["lat"]
        lons[image_id] = img["lon"]
    return image_file_names, feature_pathes, captions, labels, lats, lons
def get_data(image_file_names, captions, feature_pathes, labels, lats, lons, image_id):
    """Look up one image id in each of the per-image dicts.

    Note that the return order differs from the parameter order: the feature
    path comes before the caption in the result.

    Args:
        image_file_names: Dict mapping image id to image path.
        captions: Dict mapping image id to caption.
        feature_pathes: Dict mapping image id to feature path.
        labels: Dict mapping image id to labels.
        lats: Dict mapping image id to latitude.
        lons: Dict mapping image id to longitude.
        image_id: Key to look up in every dict.

    Returns:
        Tuple ``(image_file_name, feature_path, caption, label, lat, lon)``.

    Raises:
        KeyError: If ``image_id`` is missing from any of the dicts.
    """
    try:
        return (
            image_file_names[image_id],
            feature_pathes[image_id],
            captions[image_id],
            labels[image_id],
            lats[image_id],
            lons[image_id],
        )
    except KeyError as err:
        # Same exception type as a plain lookup, but the message names the id.
        raise KeyError(f"image_id {image_id!r} not found in dataset index") from err
def read_by_image_id(data_dir, imgs_folder, feature_folder, image_id=None):
    """Load the image, caption, feature, labels and location for one image id.

    Args:
        data_dir: Path of the dataset JSON file (despite the name, it is
            passed directly to ``read_json``).
        imgs_folder: Root folder of the images, forwarded to ``get_file_names``.
        feature_folder: Root folder of the ``.npy`` features, forwarded to
            ``get_file_names``.
        image_id: Key of the form ``"<sample_id>_<image_name>"``. The default
            ``None`` never matches an id built by ``get_file_names``, so a
            real id must be supplied.

    Returns:
        Tuple ``(img, img_, caption, image_feature, label, label_en, lat, lon)``:
            img: The image as an RGB ``PIL.Image``.
            img_: ``img`` after resize / center-crop to 224, conversion to a
                tensor and normalization; a float32 ``torch.Tensor``.
            caption: The record's ``description``.
            image_feature: Whatever ``np.load`` returns for the feature file.
            label: The record's ``labels``.
            label_en: Names from ``data_info['labels']`` whose id occurs in
                ``label``.
            lat, lon: The record's ``lat`` and ``lon``.

    Raises:
        KeyError: If ``image_id`` is not present in the dataset index.
    """
    data_info = read_json(data_dir)
    image_file_names, image_features_path, captions, labels, lats, lons = get_file_names(
        data_info, imgs_folder, feature_folder
    )
    image_file_name, image_feature_path, caption, label, lat, lon = get_data(
        image_file_names, captions, image_features_path, labels, lats, lons, image_id
    )

    # Translate label ids to names. Names follow the order of the label table,
    # and an id repeated in `label` yields its name once per occurrence.
    label_en = [
        label_name
        for label_name, label_id in data_info['labels'].items()
        for label_single in label
        if label_single == label_id
    ]

    image_feature = np.load(image_feature_path)

    # convert() returns an independent image, so the file handle can be
    # released as soon as the conversion is done.
    with Image.open(image_file_name) as raw_img:
        img = raw_img.convert('RGB')

    # NOTE(review): the mean/std constants look like the CLIP image
    # normalization values - confirm against the feature extractor in use.
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])
    # The pipeline already yields a tensor; .float() keeps the float32
    # guarantee without a detour through NumPy.
    img_ = transform(img).float()

    return img, img_, caption, image_feature, label, label_en, lat, lon
# Example configuration for a manual test run of read_by_image_id.
# The paths are specific to the original author's machine; adjust them before
# uncommenting the call at the bottom.
# data_dir is the dataset JSON file that read_by_image_id passes to read_json.
data_dir = '/data02/xy/dataEngine/json_data/merged_output_combined_9w_resplit.json'
# Both folders end with '/', as the path building in get_file_names expects.
imgs_folder = '/data02/xy/Clip-hash//datasets/image/'
feature_folder = '/data02/xy/Clip-hash/image_feature/georsclip_21_r0.9_fpn/'
# Image ids have the form '<sample_id>_<image_name>'.
image_id = 'sample44_889.jpg'
# Uncomment to load the sample and print every returned value:
# img, img_, caption, image_feature, label, label_en, lat, lon = read_by_image_id(data_dir, imgs_folder, feature_folder, image_id)
# print(img, img_, caption, image_feature, label, label_en, lat, lon)