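"""Extract per-region image crops, captions, and annotations from an eval dataset into TSV files.

For every region (bounding box) of every sample in the configured eval split, this script crops
the region from the image, base64-encodes the crop, and writes three TSV files under
<training.output_dir>/region_img_annot_caption:

    *.region_img.tsv    (region name, base64-encoded crop)
    *.region_cap.tsv    (region name, JSON list of caption dicts)
    *.region_annot.tsv  (region name, JSON dict with mask area and image size)

Illustrative invocation (script path and override keys are assumptions; see src/conf/conf.yaml):

    python <this_script>.py training.output_dir=outputs/my_run +num_rows=100 +force_overwrite=true
"""
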
import sys

sys.path.append(".")

import base64
import io
import json
import logging
import os
import os.path as osp

import datasets
import hydra
import numpy as np
import pycocotools.mask
import torch
import tqdm
from hydra.core.hydra_config import HydraConfig
from hydra.core.utils import configure_log
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from PIL import Image

from utils.git_utils import TSVWriter
from src.arguments import Arguments
from src.train import prepare_datasets

logger = logging.getLogger(__name__)


def ensure_correct_bbox(x, y, x2, y2, image_w, image_h):
    """Clamp an (x, y, x2, y2) box to the image bounds."""
    x = max(0, x)
    y = max(0, y)
    x2 = min(image_w - 1, x2)
    y2 = min(image_h - 1, y2)
    return x, y, x2, y2
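

# NOTE (assumed sample schema, inferred from the field accesses in gen_rows below):
#   {"image": PIL.Image or ndarray, "image_id": str or int,
#    "regions": [{"region_id": ..., "x": ..., "y": ..., "width": ..., "height": ...,
#                 "phrases": [str, ...], "mask": COCO RLE dict (optional)}, ...]}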
def gen_rows(dataset, img_enc_type="PNG", num_rows=None):
    """Yield one (sub_image_name, base64_bytes, captions, area, image_size) row per region crop."""
    region_cnt = 0
    sample_cnt = 0
    sent_cnt = 0

    # The dataset dict is expected to contain a single split; take it.
    dataset = next(iter(dataset.values()))
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=10, collate_fn=lambda x: x[0]
    )
    tbar = tqdm.tqdm(dataloader)
    for sample_idx, sample in enumerate(tbar):
        if num_rows is not None and sample_idx >= num_rows:
            break
        image = sample["image"]
        image = np.array(image)
        image_id = sample["image_id"]
        image_h, image_w = image.shape[:2]

        for annot_idx, region in enumerate(sample["regions"]):
            mask = region.get("mask", None)
            if mask is None:
                area = -1
            else:
                area = pycocotools.mask.area(mask)

            seg_id = region["region_id"]
            x, y, w, h = region["x"], region["y"], region["width"], region["height"]
            x2, y2 = x + w, y + h
            try:
                x, y, x2, y2 = ensure_correct_bbox(x, y, x2, y2, image_w, image_h)
                if x >= x2 or y >= y2:
                    raise ValueError(
                        f"Invalid bbox: {x, y, x2, y2} (x, y, x2, y2), at {sample_idx}-{image_id}-{annot_idx}-{seg_id} (sample_idx-image_id-annot_idx-seg_id)"
                    )

                sub_image = image[y:y2, x:x2].copy()

            except Exception as e:
                logger.error(f"[Crop] Error cropping image: {e}")
                continue

            # Encode the crop as base64 so it can be stored in a single TSV cell.
            sub_image_name = f"{image_id}-{annot_idx}-{seg_id}"
            buffer = io.BytesIO()
            Image.fromarray(sub_image).save(buffer, format=img_enc_type)
            img_bytes = buffer.getvalue()
            base64_bytes = base64.b64encode(img_bytes)

            cap = region["phrases"]
            sent_cnt += len(cap)
            region_cnt += 1

            yield sub_image_name, base64_bytes, cap, int(area), [int(image_h), int(image_w)]

        sample_cnt += 1
        tbar.set_description(f"Processed {sample_cnt} samples and {region_cnt} regions.")

    logger.info(f"Total samples: {sample_cnt}, total regions: {region_cnt}, total sents: {sent_cnt}")


@hydra.main(version_base="1.3", config_path="../../src/conf", config_name="conf")
def main(cfg: Arguments):
    output_dir = cfg.training.output_dir
    if output_dir is None:
        raise ValueError("training.output_dir is None; please set an output directory.")
    output_dir = osp.join(output_dir, "region_img_annot_caption")
    os.makedirs(output_dir, exist_ok=True)

    _, eval_dataset = prepare_datasets(cfg)
    eval_data = cfg.eval_data

    if len(eval_data) > 1:
        raise ValueError(f"Only one eval dataset is allowed, got {cfg.eval_data}")

    num_rows = cfg.get("num_rows", None)
    img_enc_type = cfg.get("img_enc_type", "PNG")
    force_overwrite = cfg.get("force_overwrite", False)

    data_path = eval_data[0].path
    data_path = osp.basename(data_path)
    data_name = eval_data[0].name
    data_split = eval_data[0].split

    if data_split is None:
        raise ValueError(
            "Data split is None. "
            f"Please specify the split in: {list(eval_dataset.keys())}. "
            f"e.g., +split={list(eval_dataset.keys())[0]}"
        )

    extract_from_one_split(
        output_dir, eval_dataset, num_rows, img_enc_type, data_split, data_path, data_name, force_overwrite
    )


def extract_from_one_split(
    output_dir, dataset, num_rows, img_enc_type, data_split, data_path, data_name, force_overwrite
):
    """Write region crops, captions, and annotations for one dataset split to three TSV files."""
    output_name = f"{data_path}-{data_name}-{data_split}"
    output_img_path = f"{output_dir}/{output_name}.region_img.tsv"
    output_cap_path = f"{output_dir}/{output_name}.region_cap.tsv"
    output_annot_path = f"{output_dir}/{output_name}.region_annot.tsv"

    logger.info(f"Output dir: {output_dir}, Data split: {data_split}")
    logger.info(f"\tOutput img path: {output_img_path}")
    logger.info(f"\tOutput cap path: {output_cap_path}")
    logger.info(f"\tOutput annot path: {output_annot_path}")

    if osp.exists(output_img_path) or osp.exists(output_cap_path) or osp.exists(output_annot_path):
        if force_overwrite:
            logger.info(
                f"Force overwriting existing output tsv files:\n\t{output_img_path},\n\t{output_cap_path},\n\t{output_annot_path}"
            )
        else:
            raise ValueError(
                f"Output tsv files already exist; set force_overwrite=True to overwrite:\n\t{output_img_path},\n\t{output_cap_path},\n\t{output_annot_path}"
            )

    with TSVWriter(output_img_path) as img_tsv_writer, TSVWriter(output_cap_path) as cap_tsv_writer, TSVWriter(
        output_annot_path
    ) as annot_tsv_writer:
        for row in gen_rows(dataset, img_enc_type=img_enc_type, num_rows=num_rows):
            sub_image_name, base64_bytes, cap, area, image_size = row
            cap = json.dumps([dict(caption=cap)])
            img_tsv_writer.write((sub_image_name, base64_bytes))
            cap_tsv_writer.write((sub_image_name, cap))
            annot = json.dumps(dict(area=area, image_size=image_size))
            annot_tsv_writer.write((sub_image_name, annot))
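

# A minimal reader-side sketch (assumption: TSVWriter emits one tab-separated
# "name<TAB>value" row per region) showing how a row of *.region_img.tsv could be
# decoded back into an image; `line` is an illustrative variable holding one TSV row:
#
#     name, b64 = line.rstrip("\n").split("\t")
#     img = Image.open(io.BytesIO(base64.b64decode(b64)))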


if __name__ == "__main__":
    main()