|
import logging |
|
import os |
|
import os.path as osp |
|
import socket |
|
import sys |
|
from dataclasses import dataclass, field |
|
from typing import Optional, List, Any, Tuple, Dict |
|
|
|
import datasets |
|
import torch |
|
import transformers |
|
from hydra.core.config_store import ConfigStore |
|
from omegaconf import DictConfig, OmegaConf |
|
from transformers import Seq2SeqTrainingArguments, TrainingArguments |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
@dataclass |
|
class SCATrainingArguments(TrainingArguments): |
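    """``transformers.TrainingArguments`` extended with the project-specific options
    used in this repo (mask counts, sample limits, custom per-parameter learning
    rates, and inference/ablation switches)."""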
|
report_to: Any = field( |
|
default="none" |
|
) |
|
remove_unused_columns: bool = field(default=False) |
|
|
|
|
|
|
|
per_device_eval_batch_size: int = field(default=1) |
|
|
|
|
|
label_names: List[str] = field(default_factory=lambda: ["labels"]) |
|
|
|
|
|
predict_with_generate: bool = field(default=True) |
|
|
|
|
|
|
|
|
|
log_level: str = field(default="info") |
|
|
|
|
|
num_masks_per_sample: Optional[int] = field(default=None) |
|
|
|
max_train_samples: Optional[int] = field(default=None) |
|
max_eval_samples: Optional[int] = field(default=None) |
|
max_predict_samples: Optional[int] = field(default=None) |
|
|
|
|
|
output_log_dir: Optional[str] = field(default=None) |
|
|
|
|
|
do_inference: bool = field(default=False) |
|
|
|
|
|
evaluate_before_train: bool = field(default=False) |
|
|
|
|
|
trainable_params: Optional[List[str]] = field(default=None) |
|
custom_param_lrs: Dict[str, float] = field( |
|
default_factory=lambda: dict(), |
|
metadata={ |
|
"help": "custom param lrs, prefix: lr, e.g., language_model, prefix: lr, e.g., +training.custom_param_lrs='{language_model:0.1}'" |
|
}, |
|
) |
|
|
|
|
|
compute_metrics: Optional[bool] = field(default=None) |
|
|
|
|
|
|
|
|
|
data_transforms: Optional[Any] = field(default=None) |
|
|
|
|
|
|
|
|
|
data_collator: Optional[Any] = field(default=None) |
|
|
|
|
|
|
|
|
|
load_best_model_at_end: bool = field(default=True) |
|
|
|
save_total_limit: int = field(default=2) |
|
    save_strategy: str = field(default="steps")
|
evaluation_strategy: str = field(default="steps") |
|
|
|
|
|
generate_chunk_size: Optional[int] = field(default=None) |
|
|
|
|
|
prompt_types_to_ablate_on_vg: Optional[str] = field( |
|
default=None |
|
) |
|
|
|
    # Deferred-validation flag: global_setup() flips this to True right before the Hydra
    # config is materialized, gating the extra checks in __post_init__ below.
    _run_post_init: bool = field(default=False)
|
|
|
def __post_init__(self): |
|
|
|
|
|
|
|
if self.report_to != "none": |
|
raise ValueError(f"report_to must be None, got {self.report_to}") |
|
if self.label_smoothing_factor != 0: |
|
raise ValueError( |
|
f"label_smoothing_factor must be 0 as the first output of the model is not language model logits, got {self.label_smoothing_factor}" |
|
) |
|
if self._run_post_init: |
|
if self.per_device_eval_batch_size != 1: |
|
raise ValueError( |
|
"per_device_eval_batch_size must be 1, " |
|
"since we cannot batchify different " |
|
"number of masks per sample during eval." |
|
) |
|
super().__post_init__() |
|
|
|
|
|
@dataclass |
|
class _Seq2SeqTrainingArguments(Seq2SeqTrainingArguments): |
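    # These fields re-declare transformers fields whose annotated types (unions /
    # config objects) OmegaConf cannot represent in a structured config; `Any` keeps
    # them storable, and a couple of version-dependent fields get explicit defaults.
    # (Assumed intent, inferred from usage.)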
|
|
|
|
|
debug: Any |
|
fsdp: Any |
|
|
|
generation_config: Any |
|
|
|
neftune_noise_alpha: Any = None |
|
sharded_ddp: Any = "" |
|
|
|
|
|
@dataclass |
|
class SCASeq2SeqTrainingArguments(SCATrainingArguments, _Seq2SeqTrainingArguments): |
|
pass |
|
|
|
|
|
@dataclass |
|
class ModelArguments: |
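    """Base model arguments shared by all model variants."""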
|
model_max_length: int = field(default=20) |
|
cache_dir: str = field(default=".model.cache") |
|
|
|
|
|
@dataclass |
|
class SAMCaptionerModelArguments(ModelArguments): |
|
sam_model_name_or_path: str = field(default="facebook/sam-vit-huge") |
|
captioner_model_name_or_path: str = field(default="Salesforce/blip-image-captioning-base") |
|
dtype: str = field(default="float16") |
|
use_vcot: bool = field(default=False) |
|
|
|
|
|
@dataclass |
|
class SCAModelBaseArguments(ModelArguments): |
|
model_name_or_path: Optional[str] = field(default=None) |
|
sam_model_name_or_path: str = field(default="facebook/sam-vit-huge") |
|
lm_head_model_name_or_path: str = field(default="gpt2") |
|
additional_num_hidden_layers: int = field(default=2) |
|
|
|
|
|
@dataclass |
|
class SCAModelArguments(SCAModelBaseArguments): |
|
num_caption_tokens: int = field(default=1) |
|
|
|
|
|
@dataclass |
|
class SCADirectDecodingModelArguments(SCAModelBaseArguments): |
|
pass |
|
|
|
|
|
@dataclass |
|
class SCAMultitaskModelArguments(SCAModelBaseArguments): |
|
num_caption_tokens: int = field(default=1) |
|
num_task_tokens: int = field(default=6) |
|
|
|
|
|
@dataclass |
|
class ScaMultitaskV2ModelArguments(SCAModelBaseArguments): |
|
num_caption_tokens: int = field(default=1) |
|
num_task_tokens: int = field(default=6) |
|
num_caption_heads: int = field(default=1) |
|
|
|
|
|
@dataclass |
|
class SCAMultitaskSplitMixerModelArguments(SCAModelBaseArguments): |
|
num_caption_tokens: int = field(default=1) |
|
num_task_tokens: int = field(default=6) |
|
num_caption_heads: int = field(default=1) |
|
|
|
|
|
@dataclass |
|
class SCADirectDecodingV2ModelArguments(SCAModelBaseArguments): |
|
num_task_tokens: int = field(default=6) |
|
|
|
|
|
@dataclass |
|
class SCAMultitaskROIPoolModelArguments(SCAModelBaseArguments): |
|
num_task_tokens: int = field(default=6) |
|
vl_projector_type: str = field(default="linear") |
|
vl_projector_norm_type: str = field(default="none") |
|
|
|
|
|
@dataclass |
|
class ScaTimmMultitaskV2ModelArguments(SCAModelBaseArguments): |
|
timm_vision_name: str = field(default="vit_base_patch16_clip_224.openai") |
|
num_caption_tokens: int = field(default=1) |
|
num_task_tokens: int = field(default=6) |
|
num_caption_heads: int = field(default=1) |
|
|
|
|
|
@dataclass |
|
class DataArguments: |
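    """Dataset-loading arguments: the fields mirror ``datasets.load_dataset`` and
    ``_target_`` names the loader in Hydra's instantiate style."""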
|
_target_: str = "datasets.load_dataset" |
|
path: Optional[str] = field(default=None) |
|
name: Optional[str] = field(default=None) |
|
split: Optional[str] = field(default=None) |
|
cache_dir: str = field(default=".data.cache") |
|
streaming: bool = field(default=False) |
|
|
|
|
|
@dataclass |
|
class VGDenseCapDataArgument(DataArguments): |
|
path: str = field(default=osp.join(osp.dirname(__file__), "data", "data_scripts", "visual_genome.py")) |
|
name: str = "region_descriptions_v1.2.0" |
|
|
|
base_image_url: Optional[str] = field(default=None) |
|
base_annotation_url: Optional[str] = field(default=None) |
|
sas_key: Optional[str] = field(default=None) |
|
|
|
use_densecap_splits: bool = field(default=True) |
|
|
|
with_image: bool = field(default=True) |
|
|
|
def __post_init__(self): |
|
if self.base_image_url is None: |
|
raise ValueError( |
|
"base_image_url must be specified in VGDenseCapDataArgument, since VisualGenome is not public available." |
|
) |
|
if self.base_annotation_url is None: |
|
raise ValueError( |
|
"base_annotation_url must be specified in VGDenseCapDataArgument, since VisualGenome is not public available." |
|
) |
|
if self.sas_key is None: |
|
logger.warning("sas_key maybe be specified in VGDenseCapDataArgument, since we fetch data from Azure.") |
|
|
|
|
|
@dataclass |
|
class VGDenseCapLocalDataArgument(DataArguments): |
|
path: str = field( |
|
default=osp.join(osp.dirname(__file__), "data", "data_scripts", "visual_genome-densecap-local.py") |
|
) |
|
name: str = "densecap" |
|
|
|
with_image: bool = field(default=True) |
|
|
|
base_dir: Optional[str] = field(default=None) |
|
base_annotation_dir: Optional[str] = field(default=None) |
|
|
|
|
|
@dataclass |
|
class VGGRiTLocalDataArgument(DataArguments): |
|
path: str = field(default=osp.join(osp.dirname(__file__), "data", "data_scripts", "visual_genome-grit-local.py")) |
|
name: str = "grit" |
|
|
|
with_image: bool = field(default=True) |
|
|
|
base_dir: Optional[str] = field(default=None) |
|
base_annotation_dir: Optional[str] = field(default=None) |
|
|
|
|
|
@dataclass |
|
class RefCOCODataArgument(DataArguments): |
|
path: str = field(default=osp.join(osp.dirname(__file__), "data", "data_scripts", "refcoco.py")) |
|
name: str = "refcoco-unc" |
|
|
|
base_url: Optional[str] = field(default=None) |
|
sas_key: Optional[str] = field(default=None) |
|
|
|
with_image: bool = field(default=True) |
|
with_mask: bool = field( |
|
default=False |
|
) |
|
|
|
|
|
@dataclass |
|
class SA1BCapDataArgument(DataArguments): |
|
path: str = field(default=osp.join(osp.dirname(__file__), "data", "data_scripts", "sa1b_cap.py")) |
|
name: str = "mask_region_descriptions_v0.0.1" |
|
|
|
sa1b_tar_url: Optional[str] = field(default=None) |
|
sa1b_tar_template: Optional[str] = field(default=None) |
|
|
|
sa1b_annot_tsv_url: Optional[str] = field(default=None) |
|
sa1b_annot_template: Optional[str] = field(default=None) |
|
|
|
sa1b_cap_tsv_url: Optional[str] = field(default=None) |
|
sa1b_cap_template: Optional[str] = field(default=None) |
|
|
|
sa1b_filter_tsv_url: Optional[str] = field(default=None) |
|
sa1b_filter_template: Optional[str] = field(default=None) |
|
|
|
sa1b_file_range: Optional[str] = field( |
|
default=None, |
|
metadata={ |
|
"help": "We use `ast.literal_eval` to parse the Python object. We assume it is a list of int or a `range` object." |
|
}, |
|
) |
|
|
|
with_image: bool = field(default=True) |
|
with_mask: bool = field( |
|
default=False |
|
) |
|
|
|
|
|
@dataclass |
|
class COCOInstanceDataArgument(DataArguments): |
|
path: str = field(default=osp.join(osp.dirname(__file__), "data", "data_scripts", "coco_instance.py")) |
|
name: str = "2017" |
|
|
|
coco_zip_url: Optional[str] = field(default=None) |
|
coco_annotations_zip_url: Optional[str] = field(default=None) |
|
|
|
with_image: bool = field(default=True) |
|
with_mask: bool = field( |
|
default=False |
|
) |
|
|
|
task_type: str = field(default="recognition") |
|
|
|
|
|
@dataclass |
|
class COCOInstanceLocalDataArgument(COCOInstanceDataArgument): |
|
path: str = field(default=osp.join(osp.dirname(__file__), "data", "data_scripts", "coco_instance-local.py")) |
|
|
|
|
|
@dataclass |
|
class Objects365LocalDataArgument(DataArguments): |
|
path: str = field(default=osp.join(osp.dirname(__file__), "data", "data_scripts", "objects365-local.py")) |
|
name: str = "v2" |
|
|
|
objects365_base_dir: Optional[str] = field(default=None) |
|
objects365_base_annotations_dir: Optional[str] = field(default=None) |
|
|
|
with_image: bool = field(default=True) |
|
with_mask: bool = field( |
|
default=False |
|
) |
|
|
|
task_type: str = field(default="recognition") |
|
|
|
|
|
@dataclass |
|
class V3DetLocalDataArgument(DataArguments): |
|
path: str = field(default=osp.join(osp.dirname(__file__), "data", "data_scripts", "v3det-local.py")) |
|
name: str = "v1" |
|
|
|
v3det_base_dir: Optional[str] = field(default=None) |
|
v3det_base_annotations_dir: Optional[str] = field(default=None) |
|
|
|
with_image: bool = field(default=True) |
|
with_mask: bool = field( |
|
default=False |
|
) |
|
|
|
task_type: str = field(default="recognition") |
|
|
|
|
|
@dataclass |
|
class SBUPseudoRegionDataArgument(DataArguments): |
|
path: str = field(default=osp.join(osp.dirname(__file__), "data", "data_scripts", "sbu-pseudo_region.py")) |
|
name: str = "pseudo_region" |
|
|
|
base_dir: Optional[str] = field(default=None) |
|
base_annotations_dir: Optional[str] = field(default=None) |
|
|
|
with_image: bool = field(default=True) |
|
with_mask: bool = field(default=False) |
|
|
|
|
|
@dataclass |
|
class SBUPseudoRegionLocalDataArgument(SBUPseudoRegionDataArgument): |
|
path: str = field(default=osp.join(osp.dirname(__file__), "data", "data_scripts", "sbu-pseudo_region-local.py")) |
|
|
|
|
|
@dataclass |
|
class COCOCaptionPseudoRegion(DataArguments): |
|
path: str = field(default=osp.join(osp.dirname(__file__), "data", "data_scripts", "coco_caption-pseudo_region.py")) |
|
name: str = "2017" |
|
|
|
coco_zip_url: Optional[str] = field(default=None) |
|
coco_annotations_zip_url: Optional[str] = field(default=None) |
|
|
|
with_image: bool = field(default=True) |
|
with_mask: bool = field(default=False) |
|
|
|
|
|
@dataclass |
|
class WandbArguments: |
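    """Weights & Biases logging configuration."""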
|
log: bool = field(default=True) |
|
project: Optional[str] = field(default="sca", metadata={"help": "wandb project"}) |
|
group: Optional[str] = field(default="debug", metadata={"help": "wandb group"}) |
|
name: Optional[str] = field(default="run", metadata={"help": "wandb run name"}) |
|
tags: Optional[List[str]] = field(default=None, metadata={"help": "wandb tags"}) |
|
resume: str = field(default="allow", metadata={"help": "wandb resume strategy"}) |
|
id: Optional[str] = field(default=None, metadata={"help": "wandb run id"}) |
|
|
|
|
|
@dataclass |
|
class DataTransformsArguments: |
|
""" |
|
NOTE: used to control large-scale jittering data augmentation. |
|
""" |
|
|
|
min_scale: float = 0.1 |
|
max_scale: float = 2.0 |
|
image_size: int = 1024 |
|
|
|
|
|
@dataclass |
|
class DataCollatorClass: |
|
use_instruction: bool = field(default=False) |
|
|
|
instruction_mapping_json: Optional[str] = field(default=None) |
|
|
|
|
|
|
|
defaults = [{"wandb": "base_wandb"}] |
|
|
|
@dataclass |
|
class M3D2DLocalDataArgument(DataArguments): |
|
path: str = field(default=osp.join(osp.dirname(__file__), "data", "data_scripts", "m3d_2d.py")) |
|
name: str = "custom" |
|
data_dir: Optional[str] = field(default=None) |
|
|
|
with_image: bool = field(default=True) |
|
|
|
task_type: str = field(default="recognition") |
|
|
|
@dataclass |
|
class Arguments: |
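    """Top-level Hydra config: training, data selection/overrides, model, wandb, and
    optional data transforms / data collator settings."""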
|
defaults: List[Any] = field(default_factory=lambda: defaults) |
|
|
|
training: SCASeq2SeqTrainingArguments = field(default_factory=lambda: SCASeq2SeqTrainingArguments(output_dir="?")) |
|
|
|
|
|
|
|
|
|
train_data: List[str] = field(default_factory=list) |
|
train_data_interleave_probabilities: Optional[List[float]] = field(default=None) |
|
train_data_overrides: List[str] = field( |
|
default_factory=list, |
|
metadata={"help": "overrides for train data. \"train_data_overrides='[data.with_image\=False]'\""}, |
|
) |
|
|
|
eval_data: List[str] = field(default_factory=list) |
|
eval_data_overrides: List[str] = field( |
|
default_factory=list, |
|
metadata={"help": "overrides for eval data. \"eval_data_overrides='[data.with_image\=False]'\""}, |
|
) |
|
|
|
model: ModelArguments = field(default_factory=ModelArguments) |
|
wandb: WandbArguments = field(default_factory=WandbArguments) |
|
|
|
|
|
|
|
data_transforms: Optional[DataTransformsArguments] = field(default=None) |
|
|
|
|
|
|
|
data_collator: DataCollatorClass = field(default_factory=DataCollatorClass) |
|
|
|
|
|
cs = ConfigStore.instance() |
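# Register the structured configs with Hydra's ConfigStore so they can be selected
# by group/name on the command line (e.g., `model=base_sca`, `data=base_vg_densecap`).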
|
cs.store(name="base_config", node=Arguments) |
|
|
|
cs.store(group="data", name="base_vg_densecap", node=VGDenseCapDataArgument) |
|
cs.store(group="data", name="base_vg_densecap_local", node=VGDenseCapLocalDataArgument) |
|
cs.store(group="data", name="base_vg_grit_local", node=VGGRiTLocalDataArgument) |
|
cs.store(group="data", name="base_refcoco", node=RefCOCODataArgument) |
|
cs.store(group="data", name="base_sa1b_cap", node=SA1BCapDataArgument) |
|
cs.store(group="data", name="base_coco_instance", node=COCOInstanceDataArgument) |
|
cs.store(group="data", name="base_coco_instance_local", node=COCOInstanceLocalDataArgument) |
|
cs.store(group="data", name="base_objects365_local", node=Objects365LocalDataArgument) |
|
cs.store(group="data", name="base_v3det_local", node=V3DetLocalDataArgument) |
|
cs.store(group="data", name="base_sbu_pseudo_region", node=SBUPseudoRegionDataArgument) |
|
cs.store(group="data", name="base_sbu_pseudo_region_local", node=SBUPseudoRegionLocalDataArgument) |
|
cs.store(group="data", name="base_coco_caption_pseudo_region", node=COCOCaptionPseudoRegion) |
|
cs.store(group="data", name="base_m3d_2d", node=M3D2DLocalDataArgument) |
|
|
|
|
|
cs.store(group="model", name="base_sam_captioner", node=SAMCaptionerModelArguments) |
|
cs.store(group="model", name="base_sca", node=SCAModelArguments) |
|
cs.store(group="model", name="base_sca_direct_decoding", node=SCADirectDecodingModelArguments) |
|
cs.store(group="model", name="base_sca_multitask", node=SCAMultitaskModelArguments) |
|
cs.store(group="model", name="base_sca_multitask_v2", node=ScaMultitaskV2ModelArguments) |
|
cs.store(group="model", name="base_sca_multitask_split_mixer", node=SCAMultitaskSplitMixerModelArguments) |
|
cs.store(group="model", name="base_sca_direct_decoding_v2", node=SCADirectDecodingV2ModelArguments) |
|
cs.store(group="model", name="base_sca_multitask_roi_pool", node=SCAMultitaskROIPoolModelArguments) |
|
cs.store(group="model", name="base_sca_timm_multitask_v2", node=ScaTimmMultitaskV2ModelArguments) |
|
|
|
|
|
cs.store(group="wandb", name="base_wandb", node=WandbArguments) |
|
|
|
cs.store(group="data_transforms", name="base_data_transforms", node=DataTransformsArguments) |
|
|
|
cs.store(group="data_collator", name="base_data_collator", node=DataCollatorClass) |
|
|
|
|
|
def global_setup( |
|
args: DictConfig, |
|
) -> Tuple[Arguments, SCASeq2SeqTrainingArguments, ModelArguments]: |
|
"""Global setup of arguments.""" |
|
if args.training.output_log_dir is not None: |
|
output_log_dir = args.training.output_log_dir |
|
        os.makedirs(output_log_dir, exist_ok=True)
|
|
|
|
|
|
|
        # Mirror the first existing log FileHandler (e.g., Hydra's job log file) into
        # output_log_dir by attaching a second FileHandler with the same file name.
        for handler in logging.root.handlers:
|
if isinstance(handler, logging.FileHandler): |
|
file_path = handler.baseFilename |
|
file_name = osp.basename(file_path) |
|
external_file_path = osp.join(output_log_dir, file_name) |
|
logging.root.addHandler(logging.FileHandler(external_file_path)) |
|
logger.info(f"Add external file handler to {external_file_path}") |
|
break |
|
|
|
hostname = socket.gethostname() |
|
logger.info(f"Running on {hostname}") |
|
|
|
|
|
|
|
    # `_n_gpu` is a non-init field of TrainingArguments; delete it from the config,
    # presumably so that OmegaConf.to_object() below can re-instantiate the dataclass.
    del args.training._n_gpu
|
|
|
    # The config is final now, so enable the deferred __post_init__ validation.
    args.training._run_post_init = True
|
|
|
    if not args.training.do_eval and not args.training.do_train:
|
args.training.evaluation_strategy = "no" |
|
args.training.load_best_model_at_end = False |
|
|
|
training_args = OmegaConf.to_object(args.training) |
|
model_args = OmegaConf.to_object(args.model) |
|
|
|
if ( |
|
isinstance(model_args, (SCAModelArguments, SCADirectDecodingModelArguments)) |
|
and args.model.model_name_or_path is None |
|
): |
|
|
|
|
|
raise ValueError(f"{type(model_args)} is not supported in model cfg name.") |
|
|
|
log_level = training_args.get_process_log_level() |
|
logger.setLevel(log_level) |
|
datasets.utils.logging.set_verbosity(log_level) |
|
transformers.utils.logging.set_verbosity(log_level) |
|
transformers.utils.logging.enable_default_handler() |
|
transformers.utils.logging.enable_explicit_format() |
|
|
|
|
|
logger.warning( |
|
f"Process rank: {training_args.local_rank}, device: {training_args.device}," |
|
f" log_level: {log_level} n_gpu: {training_args.n_gpu}" |
|
f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits" |
|
f" training: {training_args.fp16}, bf16 training: {training_args.bf16}" |
|
) |
|
logger.debug(f"Training/evaluation parameters {training_args}") |
|
|
|
return args, training_args, model_args |
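

# A minimal consumption sketch (an assumption -- the real entry point lives in a
# separate training script and may differ):
#
#   @hydra.main(version_base=None, config_name="base_config")
#   def main(cfg: DictConfig) -> None:
#       cfg, training_args, model_args = global_setup(cfg)
#       ...  # build the datasets/model and hand training_args to the trainer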
|
|