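# SmallCapDemo / extract_features.py
# Uploaded by RitaParadaRamos ("Upload 11 files", commit 7c4b306).
# File size: 1.91 kB. Virus scan: no virus.
# (The "raw" and "history blame" entries were page-navigation links on the
# hosting site, not part of the script.)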
import os
import sys
import pandas as pd
import json
from tqdm import tqdm
from PIL import Image
import torch
from multiprocessing import Pool
import h5py
from transformers import logging
from transformers import CLIPFeatureExtractor, CLIPVisionModel
# Restrict transformers' own logging to errors only.
logging.set_verbosity_error()

# Images are read from data_dir; per-split HDF5 feature files are written
# under features_dir (both are used by encode_split).
data_dir = 'data/images/'
features_dir = 'features/'

# Run the encoder on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained CLIP vision encoder and its matching image preprocessor.
# NOTE(review): recent transformers releases deprecate CLIPFeatureExtractor
# in favour of CLIPImageProcessor -- confirm against the pinned version.
encoder_name = 'openai/clip-vit-base-patch32'
feature_extractor = CLIPFeatureExtractor.from_pretrained(encoder_name)
clip_encoder = CLIPVisionModel.from_pretrained(encoder_name).to(device)

# Read the annotation records inside a context manager so the file handle is
# closed deterministically (the previous bare open() was never closed), and
# decode as UTF-8 explicitly rather than relying on the locale default.
with open('data/dataset_coco.json', encoding='utf-8') as _annotations_file:
    annotations = json.load(_annotations_file)['images']
def load_data():
    """Group the module-level ``annotations`` records by split.

    Each record contributes a ``{'file_name', 'cocoid'}`` entry, where
    ``file_name`` is the part of the record's ``'filename'`` after its last
    underscore. Records whose split is ``'train'`` or ``'restval'`` go to the
    ``'train'`` list, ``'val'`` records go to the ``'val'`` list, and records
    of any other split are ignored.

    Returns:
        A dict with the keys ``'train'`` and ``'val'``, each mapping to a
        list of entries in annotation order.
    """
    train_entries = []
    val_entries = []
    for record in annotations:
        # Keep only the trailing component of the file name.
        image_name = record['filename'].rsplit('_', 1)[-1]
        split_name = record['split']
        if split_name in ('train', 'restval'):
            bucket = train_entries
        elif split_name == 'val':
            bucket = val_entries
        else:
            continue
        bucket.append({'file_name': image_name, 'cocoid': record['cocoid']})
    return {'train': train_entries, 'val': val_entries}
def encode_split(data, split, batch_size=256):
    """Encode every image of one split with CLIP and store the features.

    Images are read from ``data_dir`` in batches, preprocessed with the
    module-level ``feature_extractor`` and passed through ``clip_encoder``.
    Each image's ``last_hidden_state`` is written to
    ``features_dir + '<split>.hdf5'`` as a dataset named after the string
    form of its ``cocoid``. An existing file for the split is overwritten.

    Args:
        data: Mapping from split name to records carrying ``'file_name'``
            and ``'cocoid'`` entries (as produced by ``load_data``).
        split: Key of ``data`` to encode; also the stem of the output file.
        batch_size: Number of images encoded per forward pass.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer')

    df = pd.DataFrame(data[split])

    # h5py does not create missing parent directories for the output file.
    if features_dir:
        os.makedirs(features_dir, exist_ok=True)

    # The context manager guarantees the HDF5 file is flushed and closed even
    # when encoding fails part-way through.
    with h5py.File(features_dir + '{}.hdf5'.format(split), 'w') as h5py_file:
        for start in tqdm(range(0, len(df), batch_size)):
            batch = df.iloc[start:start + batch_size]
            cocoids = batch['cocoid']
            file_names = batch['file_name']

            images = []
            for file_name in file_names:
                # convert() returns an independent in-memory copy, so the
                # source file can be closed right away.
                with Image.open(data_dir + file_name) as image:
                    images.append(image.convert("RGB"))

            # Inference only: no gradients are needed.
            with torch.no_grad():
                pixel_values = feature_extractor(images, return_tensors='pt').pixel_values.to(device)
                encodings = clip_encoder(pixel_values=pixel_values).last_hidden_state.cpu().numpy()

            for cocoid, encoding in zip(cocoids, encodings):
                # Let h5py take the dataset shape from the array instead of
                # hard-coding the output shape of one particular encoder.
                h5py_file.create_dataset(str(cocoid), data=encoding)
# Build the per-split record lists once, then extract and store the features
# for the training split followed by the validation split.
data = load_data()
for _split_name in ('train', 'val'):
    encode_split(data, _split_name)