|
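"""Gradio app for browsing dataset samples before and after the data transform.

For each sample it shows the image, a textual summary of the fields, and a table
of regions; selecting a row overlays that region's box or mask on the image.
"""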
import sys

# Make the repository root importable (for `src.*`) when running from the repo root.
sys.path.append(".")

import logging
import os
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional

import dotenv
import gradio as gr
import hydra
import numpy as np
import pandas as pd
import pycocotools.mask
import torch
from datasets import DatasetDict, IterableDatasetDict, load_dataset
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from PIL import Image
from transformers import Trainer, set_seed
from transformers.trainer_utils import get_last_checkpoint

from src.arguments import Arguments, global_setup, SAMCaptionerModelArguments, SCAModelArguments, SCAModelBaseArguments
from src.data.collator import SamCaptionerDataCollator
from src.data.transforms import SamCaptionerDataTransform
from src.models.sam_captioner import SAMCaptionerConfig, SAMCaptionerModel, SAMCaptionerProcessor
from src.models.sca import ScaProcessor
from src.train import prepare_data_transform, prepare_datasets, prepare_processor
|
logger = logging.getLogger(__name__) |
|
|
@hydra.main(version_base="1.3", config_path="../../src/conf", config_name="conf") |
def main(args: DictConfig) -> None: |
    logger.info(OmegaConf.to_yaml(args))
    args, training_args, model_args = global_setup(args)

    set_seed(args.training.seed)
|
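    # Build the train/eval dataset splits from the Hydra config.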
train_dataset, eval_dataset = prepare_datasets(args) |
|
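    # Optional secrets (e.g., sas_key, USE_AUTH_TOKEN) can be supplied via a local .env file.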
logger.info(f"Try to load sas_key from .env file: {dotenv.load_dotenv('.env')}.") |
|
use_auth_token = os.getenv("USE_AUTH_TOKEN", False) |
|
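    # The processor bundles the SAM image processor; its normalization stats are read below.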
processor = prepare_processor(model_args, use_auth_token) |
|
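    # Attach the model-specific data transform to both splits.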
    train_dataset, eval_dataset = prepare_data_transform(
        training_args, model_args, train_dataset, eval_dataset, processor
    )
|
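    # Image normalization stats, used later to denormalize pixel_values for display.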
    global image_mean, image_std
    image_mean, image_std = (
        processor.sam_processor.image_processor.image_mean,
        processor.sam_processor.image_processor.image_std,
    )
|
def view_one_batch(dataset_split, batch_idx, dataset_type): |
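        """Dispatch to the raw or transformed sample viewer based on dataset_type."""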
|
        if dataset_type == "before_transform":
            return _view_one_batch_before_transform(dataset_split, batch_idx, dataset_type)
        elif dataset_type == "after_transform":
            return _view_one_batch_after_transform(dataset_split, batch_idx, dataset_type)
        else:
            raise ValueError(f"Unknown dataset_type: {dataset_type}")
|
def _view_one_batch_before_transform(dataset_split, batch_idx, dataset_type): |
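        """Raw sample: return the PIL image, a field summary, and a sorted regions table."""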
|
        sample = dataset_split[batch_idx]
        image = sample["image"]

        # Show scalar and string fields in the text box.
        text = f"dataset_type: {dataset_type}\nsample_id: {batch_idx}\n"
        for k, v in sample.items():
            if isinstance(v, (int, str)):
                text += f"{k}: {v}\n"

        # Regions come as a list of dicts; tabulate and sort them by region_id.
        regions = sample["regions"]
        regions = pd.DataFrame(regions)
        regions.sort_values(by=["region_id"], ascending=True, inplace=True)
        return image, text, regions
|
def _view_one_batch_after_transform(dataset_split, batch_idx, dataset_type): |
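        """Transformed sample: denormalize pixel_values and tabulate per-region fields."""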
|
        sample = dataset_split[batch_idx]

        # pixel_values is the normalized image tensor; undo the SAM normalization
        # to recover a displayable RGB image.
        image = sample["pixel_values"]
        image_mean_tensor = torch.tensor(image_mean).view(3, 1, 1)
        image_std_tensor = torch.tensor(image_std).view(3, 1, 1)
        image = image * image_std_tensor + image_mean_tensor
        image = image.clamp(0, 1) * 255
        image = image.permute(1, 2, 0).numpy().astype(np.uint8)
|
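        # Summarize every field: print small values directly, otherwise shapes/lengths.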
        PRINT_VALUE_KEYS = ["original_sizes", "reshaped_input_sizes"]
        text = f"dataset_type: {dataset_type}\nsample_id: {batch_idx}\n"
        for k, v in sample.items():
            text += f"{k}:\t{type(v)}\t"
            if k in PRINT_VALUE_KEYS:
                text += f"{v}\n"
            elif isinstance(v, str):
                text += f"{v}\n"
            elif isinstance(v, torch.Tensor):
                text += f"{v.shape}\n"
            elif isinstance(v, list):
                text += f"{len(v)}\n"
            elif isinstance(v, np.ndarray):
                text += f"{v.shape}\n"
            else:
                try:
                    # Objects with a .size attribute (e.g., PIL images) print their size.
                    text += f"{v.size}\n"
                except AttributeError:
                    text += f"{v}\n"
|
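        # Collect the per-region fields into a DataFrame, one row per region.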
        REGION_KEYS = [
            "input_boxes",
            "metadata_input_boxes",
            "metadata_image_id",
            "metadata_region_id",
            "metadata_captions",
        ]
        pd_series = []
        for region_tensor_key in REGION_KEYS:
            region_tensor = sample[region_tensor_key]

            if region_tensor_key == "input_boxes":
                # Cast box coordinates to integers for a compact display.
                if isinstance(region_tensor, torch.Tensor):
                    region_tensor = region_tensor.long()
                elif isinstance(region_tensor, np.ndarray):
                    region_tensor = region_tensor.astype(np.int64)
            if isinstance(region_tensor, (torch.Tensor, np.ndarray)):
                region_list = region_tensor.tolist()
            elif isinstance(region_tensor, list):
                region_list = region_tensor
            else:
                raise ValueError(f"Unknown type of region_tensor: {type(region_tensor)}")
            pd_series.append(pd.Series(region_list, name=region_tensor_key))
        regions = pd.concat(pd_series, axis=1)
        regions.sort_values(by=["metadata_region_id"], ascending=True, inplace=True)

        return image, text, regions
|
    def view_one_region(image, data_frame, output_choice_radio, dataset_type, evt: gr.SelectData):
        """Overlay the region selected in the table onto the image."""
        if dataset_type == "before_transform":
            return _view_one_region_before_transform(image, data_frame, output_choice_radio, evt)
        elif dataset_type == "after_transform":
            return _view_one_region_after_transform(image, data_frame, output_choice_radio, evt)
        else:
            raise ValueError(f"Unknown dataset_type: {dataset_type}")
|
    def _view_one_region_before_transform(image, data_frame, output_choice_radio, evt):
        row_id, _ = evt.index
        region = data_frame.iloc[row_id]
        if output_choice_radio == "segmentation" and region.get("mask", None) is not None:
            # Masks are stored as COCO RLE; decode to a binary mask for overlay.
            annot = region["mask"]
            annot = pycocotools.mask.decode(annot)
        elif output_choice_radio == "segmentation" and region.get("mask", None) is None:
            # No mask available for this region; fall back to its bounding box.
            x, y, w, h = region["x"], region["y"], region["width"], region["height"]
            x2, y2 = x + w, y + h
            annot = [x, y, x2, y2]
        elif output_choice_radio == "bbox":
            x, y, w, h = region["x"], region["y"], region["width"], region["height"]
            x2, y2 = x + w, y + h
            annot = [x, y, x2, y2]
        else:
            raise ValueError(f"Unknown output_choice_radio: {output_choice_radio}")

        phrases = [f"{idx}: {phrase}" for idx, phrase in enumerate(region["phrases"])]
        phrases = "; ".join(phrases)
        return image, [[annot, phrases]]
|
    def _view_one_region_after_transform(image, data_frame, output_choice_radio, evt):
        row_id, _ = evt.index
        region = data_frame.iloc[row_id]
        if output_choice_radio == "segmentation" and region.get("mask", None) is not None:
            raise NotImplementedError("TODO: implement segmentation for after_transform")
        elif output_choice_radio == "segmentation" and region.get("mask", None) is None:
            annot = list(map(int, region["input_boxes"]))
        elif output_choice_radio == "bbox":
            annot = list(map(int, region["input_boxes"]))
        else:
            raise ValueError(f"Unknown output_choice_radio: {output_choice_radio}")

        # metadata_captions may be a flat list of strings; wrap it so each entry
        # enumerated below is one caption group.
        phrases = region["metadata_captions"]
        if not isinstance(phrases[0], list):
            phrases = [phrases]
        phrases = [f"{idx}: {phrase}" for idx, phrase in enumerate(phrases)]
        phrases = "; ".join(phrases)
        return image, [[annot, phrases]]
|
    def get_gr_frame(frame_name, dataset_split):
        # Transformed samples carry an "images" key; raw samples do not.
        dataset_type = "before_transform" if dataset_split[0].get("images", None) is None else "after_transform"
        dataset_type = gr.State(dataset_type)  # gr.State supersedes the deprecated gr.Variable
        with gr.Accordion(label=frame_name) as frame:
            # The slider picks a 0-based sample index, hence maximum is len - 1.
            batch_idx = gr.Slider(minimum=0, maximum=len(dataset_split) - 1, step=1, value=0)
            button = gr.Button(value="View the batch")
            output_choice_radio = gr.Radio(["bbox", "segmentation"], value="bbox")

            image = gr.Image(height=500)
            text = gr.Textbox(lines=1)
            data_frame = gr.DataFrame()
            annotated_image = gr.AnnotatedImage(height=500)

            dataset_split = gr.State(dataset_split)
            button.click(
                view_one_batch, inputs=[dataset_split, batch_idx, dataset_type], outputs=[image, text, data_frame]
            )
            data_frame.select(
                view_one_region,
                inputs=[image, data_frame, output_choice_radio, dataset_type],
                outputs=[annotated_image],
            )

        return frame
|
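    # One accordion per split: train first, then each validation split.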
    with gr.Blocks() as app:
        get_gr_frame("train", train_dataset)
        for eval_data_k, eval_data_v in eval_dataset.items():
            get_gr_frame(f"validate-{eval_data_k}", eval_data_v)
    app.launch()
|
|
if __name__ == "__main__":
    main()
|
|