import ast
import getopt
import json
import os
import sys
from collections import OrderedDict

import datasets
import numpy as np
import torch

from modeling_frcnn import GeneralizedRCNN
from processing_image import Preprocess
from utils import Config

""" | |
USAGE: | |
``python extracting_data.py -i <img_dir> -o <dataset_file>.datasets <batch_size>`` | |
""" | |
TEST = False
CONFIG = Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
DEFAULT_SCHEMA = datasets.Features(
    OrderedDict(
        {
            "attr_ids": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
            "attr_probs": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
            "boxes": datasets.Array2D((CONFIG.MAX_DETECTIONS, 4), dtype="float32"),
            "img_id": datasets.Value("int32"),
            "obj_ids": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
            "obj_probs": datasets.Sequence(length=CONFIG.MAX_DETECTIONS, feature=datasets.Value("float32")),
            "roi_features": datasets.Array2D((CONFIG.MAX_DETECTIONS, 2048), dtype="float32"),
            "sizes": datasets.Sequence(length=2, feature=datasets.Value("float32")),
            "preds_per_image": datasets.Value(dtype="int32"),
        }
    )
)
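
# Note: each record is padded to CONFIG.MAX_DETECTIONS region predictions, and
# "roi_features" stores one 2048-d RoI feature vector per region, matching the
# Array2D shapes declared in DEFAULT_SCHEMA above.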


class Extract:
    def __init__(self, argv=sys.argv[1:]):
        inputdir = None
        outputfile = None
        subset_list = None
        batch_size = 1
        # "-s" takes an argument (the subset-list path), so it needs a trailing colon
        opts, args = getopt.getopt(argv, "i:o:b:s:", ["inputdir=", "outfile=", "batch_size=", "subset_list="])
        for opt, arg in opts:
            if opt in ("-i", "--inputdir"):
                inputdir = arg
            elif opt in ("-o", "--outfile"):
                outputfile = arg
            elif opt in ("-b", "--batch_size"):
                batch_size = int(arg)
            elif opt in ("-s", "--subset_list"):
                subset_list = arg

        assert inputdir is not None  # and os.path.isdir(inputdir), f"{inputdir}"
        assert outputfile is not None and not os.path.isfile(outputfile), f"{outputfile}"
        # resolve paths first: _vqa_file_split (used just below) joins file names against self.inputdir
        self.inputdir = os.path.realpath(inputdir)
        self.outputfile = os.path.realpath(outputfile)
        if subset_list is not None:
            with open(os.path.realpath(subset_list)) as f:
                self.subset_list = {self._vqa_file_split(x)[0] for x in tryload(f)}
        else:
            self.subset_list = None
        self.config = CONFIG
        if torch.cuda.is_available():
            self.config.model.device = "cuda"
        self.preprocess = Preprocess(self.config)
        self.model = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=self.config)
        self.batch = batch_size if batch_size != 0 else 1
        self.schema = DEFAULT_SCHEMA

    def _vqa_file_split(self, file):
        img_id = int(file.split(".")[0].split("_")[-1])
        filepath = os.path.join(self.inputdir, file)
        return (img_id, filepath)
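
    # A minimal sketch of the filename layout _vqa_file_split assumes
    # (VQA/COCO-style names; this particular file name is hypothetical):
    #   _vqa_file_split("COCO_val2014_000000262148.jpg")
    #   -> (262148, "<inputdir>/COCO_val2014_000000262148.jpg")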

    def file_generator(self):
        batch = []
        for file in os.listdir(self.inputdir):
            # compare image ids (not the directory-listing index) against the subset
            img_id, filepath = self._vqa_file_split(file)
            if self.subset_list is not None and img_id not in self.subset_list:
                continue
            batch.append((img_id, filepath))
            if len(batch) == self.batch:
                temp = batch
                batch = []
                yield list(map(list, zip(*temp)))

        # flush the final, possibly smaller, batch; skip it when empty so the
        # caller never has to unpack an empty yield
        if batch:
            yield list(map(list, zip(*batch)))
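
    # file_generator yields transposed batches: for a batch size of 2, a yield
    # looks like [[img_id_0, img_id_1], [filepath_0, filepath_1]], which
    # __call__ below unpacks directly as (img_ids, filepaths).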

    def __call__(self):
        # make writer
        if not TEST:
            writer = datasets.ArrowWriter(features=self.schema, path=self.outputfile)
        # iterate over batches of image ids and file paths
        for i, (img_ids, filepaths) in enumerate(self.file_generator()):
            images, sizes, scales_yx = self.preprocess(filepaths)
            output_dict = self.model(
                images,
                sizes,
                scales_yx=scales_yx,
                padding="max_detections",
                max_detections=self.config.MAX_DETECTIONS,
                pad_value=0,
                return_tensors="np",
                location="cpu",
            )
            output_dict["boxes"] = output_dict.pop("normalized_boxes")
            if not TEST:
                output_dict["img_id"] = np.array(img_ids)
                batch = self.schema.encode_batch(output_dict)
                writer.write_batch(batch)
            if TEST:
                break
        # finalize the writer
        if not TEST:
            num_examples, num_bytes = writer.finalize()
            print(f"Success! You wrote {num_examples} entry(s) and {num_bytes >> 20} MB")


def tryload(stream):
    """Best-effort loader for a subset list: tries JSON (dict keys, or a list of
    records with an "img_id" field), then a Python literal, and finally falls
    back to one id per line."""
    try:
        data = json.load(stream)
        try:
            data = list(data.keys())
        except Exception:
            data = [d["img_id"] for d in data]
    except Exception:
        try:
            # literal_eval instead of eval: only parses plain Python literals
            data = ast.literal_eval(stream.read())
        except Exception:
            data = stream.read().split("\n")
    return data
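
# Hypothetical subset-list files that tryload handles (contents are
# illustrative examples, not shipped with this script):
#   {"262148": {}, "393225": {}}                  -> JSON dict: keys become ids
#   [{"img_id": "262148"}, {"img_id": "393225"}]  -> JSON records: "img_id" fields
#   ["262148", "393225"]                          -> Python literal list
#   262148 then 393225, one per line              -> plain-text fallback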


if __name__ == "__main__":
    extract = Extract(sys.argv[1:])
    extract()
    if not TEST:
        dataset = datasets.Dataset.from_file(extract.outputfile)
        # voilà!
        # print(np.array(dataset[0:2]["roi_features"]).shape)