|
import jsonlines |
|
from tqdm import tqdm |
|
import random |
|
import json |
|
import os |
|
from multiprocessing import Pool |
|
from functools import partial |
|
import emoji |
|
|
|
import argparse |
|
|
|
def clean_span(span):
    """Normalize a caption span: trim trailing whitespace, unify quote-like
    characters to "'", and drop a single trailing '/' or '.'.

    NOTE(review): several replacement targets below render as 'β' — this
    looks like mojibake of curly quotes/dashes from a bad re-encode of the
    file; the intended original characters should be confirmed upstream.
    """

    span = span.rstrip()

    # '"' and '\"' denote the same character, so the second replace is a no-op.
    span = span.replace('"', "'").replace('\"', "'").replace('β', "'").replace('β', "'")

    span = span.replace('β', "'").replace('β', "'").replace('β', "β")

    # Strip at most one trailing slash or period.
    if span.endswith('/') or span.endswith('.'):

        span = span[:-1]

    return span
|
|
|
def check_caption(cap):
    """Return True when the annotation's caption passes all quality filters.

    Rejects captions that are non-ASCII, contain any blacklisted
    character/token, contain an interior '.', or contain emoji.

    Args:
        cap: annotation dict with at least a "caption" string value.
    """

    # Trim trailing whitespace and drop the final character (presumably the
    # closing punctuation) before validating — TODO confirm captions always
    # end with punctuation; otherwise this drops a real character.
    check_anno = cap["caption"].rstrip()[:-1]

    if not str.isascii(check_anno):

        return False



    # Blacklisted characters/tokens. NOTE(review): "βοΈ" and "Β " look like
    # mojibake (likely an emoji and a non-breaking space) from a bad
    # re-encode — verify against the upstream script before changing.
    check_list = {"βοΈ", "-", ",", "Β ", "*", "/", "$", "[CLS]", "[SEP]", "?"}

    for ch in check_list:

        if ch in check_anno:

            return False

    # Reject a '.' anywhere except the last position (multi-sentence caption).
    if '.' in check_anno[:-1]:

        return False

    if emoji.emoji_count(check_anno):

        print(check_anno)

        return False

    return True
|
|
|
def get_regions(nc, anno):
    """Build one region record (bbox + phrase) from a grounding tuple.

    ``nc`` appears to be (char_start, char_end, x0, y0, x1, y1) with
    coordinates normalized to [0, 1] — TODO confirm against the GRIT
    data format.

    Args:
        nc: grounding tuple as described above.
        anno: annotation dict providing "caption", "height", and "width".

    Returns:
        dict with "bbox" (pixel coords rounded to 2 decimals) and "phrase".
    """
    height, width = anno["height"], anno["width"]
    start, end = int(nc[0]), int(nc[1])
    phrase = clean_span(anno["caption"][start:end])
    # Scale normalized (x0, y0, x1, y1) back to pixel space.
    scales = (width, height, width, height)
    bbox = [round(value * scale, 2) for value, scale in zip(nc[2:6], scales)]
    return {"bbox": bbox, "phrase": phrase}
|
|
|
|
|
def prepare_list(file_name: str, random_samples):
    """Read newline-separated annotation filenames and take a random subset.

    Args:
        file_name: path to a text file with one annotation filename per line.
        random_samples: desired sample size. Clamped to the number of
            available entries, so requesting more than exist returns all
            entries instead of raising ValueError from random.sample.

    Returns:
        (metas, num_of_files): the sampled entries and their count.
    """
    with open(file_name, "r") as f:
        metas = [line.strip() for line in f]
    num_of_files = len(metas)
    print(num_of_files)
    # Clamp so random.sample cannot fail when fewer lines exist than requested.
    metas = random.sample(metas, min(random_samples, num_of_files))
    num_of_files = len(metas)
    print("after sample:", num_of_files)
    return metas, num_of_files
|
|
|
|
|
def process_item(file, args):
    """Convert one GRIT annotation JSON file into one ODVG record.

    Randomly draws regions from either ``noun_chunks`` or ``ref_exps``
    (threshold controlled by ``args.chunk_or_ref``), keeping only regions
    whose cleaned phrase is pure ASCII.

    Args:
        file: annotation filename, relative to ``args.root``.
        args: parsed CLI namespace; uses ``root``, ``chunk_or_ref``,
            and ``min_phrase``.

    Returns:
        The ODVG annotation dict, or None when the caption fails the
        quality check or too few regions survive filtering.
    """
    with open(os.path.join(args.root, file)) as f:
        anno = json.load(f)
    if not check_caption(anno):
        return None
    # Fix: honor the --chunk_or_ref CLI option. The threshold was hard-coded
    # to 0.5, leaving the parsed argument unused; the default of 0.5
    # preserves the old behavior exactly.
    if random.random() > args.chunk_or_ref:
        source = anno['noun_chunks']
    else:
        source = anno['ref_exps']
    regions = []
    for grounding in source:
        region = get_regions(grounding, anno)
        # Drop phrases that are not pure ASCII after cleaning.
        if str.isascii(region["phrase"]):
            regions.append(region)
    if len(regions) < args.min_phrase:
        return None
    return {
        # Annotation files are named <image-id>.json; pair with the .jpg.
        "filename": f'{file.split(".")[0]}.jpg',
        "height": anno["height"],
        "width": anno["width"],
        "grounding": {
            "caption": clean_span(anno["caption"]),
            "regions": regions,
        },
    }
|
|
|
if __name__ == "__main__":
    # CLI driver: sample a list of GRIT annotation files, convert them to
    # ODVG records in parallel, and write the survivors as JSON lines.
    parser = argparse.ArgumentParser(description="GRIT2ODVG List.")
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--root", type=str, default="", help="Source image root")
    parser.add_argument("--output_file", type=str, default="girt_14m_odvg.jsonl")
    parser.add_argument("--random_samples", type=int, default=200000)
    parser.add_argument("--chunk_or_ref", type=float, default=0.5)
    parser.add_argument("--min_phrase", type=int, default=6)
    parser.add_argument("--process_num", type=int, default=10, help="the number of processes")
    args = parser.parse_args()
    print(args)

    metas, metas_len = prepare_list(args.input_file, args.random_samples)
    worker = partial(process_item, args=args)

    # Fan the conversion out over a process pool; imap preserves input order.
    results = []
    with Pool(processes=args.process_num) as pool:
        for converted in tqdm(pool.imap(func=worker, iterable=metas), total=metas_len):
            results.append(converted)

    # process_item returns None for filtered-out items; keep real records only.
    odvg_anno = [record for record in results if record]
    with jsonlines.open(args.output_file, mode="w") as writer:
        writer.write_all(odvg_anno)