|
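"""Training entry point for SAM-Captioner / SCA models, driven by Hydra.

A sketch of a typical invocation (the config names below are placeholders, not
actual config files shipped with this repo):

    python src/train.py train_data='[my_train_config]' eval_data='[my_eval_config]' training.do_train=True

All options are resolved from `conf/conf.yaml` and post-processed by `global_setup`.
"""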
import logging |
|
import os |
|
|
|
import hydra |
|
from hydra.utils import instantiate |
|
from datasets import Dataset, IterableDataset, interleave_datasets
|
from omegaconf import DictConfig, OmegaConf |
|
from src.data.transforms import SamCaptionerDataTransform, SCADataTransform |
|
from src.data.collator import SamCaptionerDataCollator, SCADataCollator |
|
from src.arguments import ( |
|
Arguments, |
|
global_setup, |
|
SAMCaptionerModelArguments, |
|
SCAModelBaseArguments, |
|
SCAModelArguments, |
|
SCADirectDecodingModelArguments, |
|
SCAMultitaskModelArguments, |
|
SCAMultitaskSplitMixerModelArguments, |
|
ScaMultitaskV2ModelArguments, |
|
VGDenseCapDataArgument, |
|
RefCOCODataArgument, |
|
SA1BCapDataArgument, |
|
COCOInstanceDataArgument, |
|
SCADirectDecodingV2ModelArguments, |
|
SCAMultitaskROIPoolModelArguments, |
|
ScaTimmMultitaskV2ModelArguments, |
|
) |
|
from src.models.sam_captioner import SAMCaptionerConfig, SAMCaptionerModel, SAMCaptionerProcessor |
|
from src.sca_seq2seq_trainer import SCASeq2SeqTrainer, get_parameter_by_name, SAVING_FINISHED_FLAG |
|
from src.models.sca import ( |
|
ScaModel, |
|
ScaConfig, |
|
ScaProcessor, |
|
ScaDirectDecodingModel, |
|
ScaMultitaskModel, |
|
ScaMultitaskSplitMixerModel, |
|
ScaMultitaskV2Model, |
|
ScaDirectDecodingV2Model, |
|
ScaMultitaskROIPoolModel, |
|
ScaTimmMultitaskV2Model, |
|
) |
|
from src.integrations import CustomWandbCallBack, EvaluateFirstStepCallback, LoggerCallback, EvalLossCallback |
|
import src.models.sca |
|
import src.utils |
|
|
|
from transformers.trainer_utils import _re_checkpoint |
|
from transformers import set_seed |
|
import json |
|
import dotenv |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
def get_last_checkpoint(folder): |
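    """Return the newest checkpoint directory in `folder` whose save finished.

    Unlike `transformers.trainer_utils.get_last_checkpoint`, this variant only
    accepts checkpoints containing the SAVING_FINISHED_FLAG file, so checkpoints
    left behind by interrupted saves are skipped.
    """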
|
content = os.listdir(folder) |
|
checkpoints = [ |
|
path |
|
for path in content |
|
if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path)) |
|
] |
|
if len(checkpoints) == 0: |
|
logger.warning(f"No checkpoint found in {folder}, but we got: {content}") |
|
return |
|
checkpoints = sorted(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0]), reverse=True) |
|
    for checkpoint in checkpoints:
        # Only resume from a checkpoint whose save fully finished (flag file present).
        if os.path.isfile(os.path.join(folder, checkpoint, SAVING_FINISHED_FLAG)):
            return os.path.join(folder, checkpoint)
        else:
            logger.warning(f"Checkpoint {os.path.join(folder, checkpoint)} does not have {SAVING_FINISHED_FLAG}, skipping")
    return None
|
|
|
|
|
@hydra.main(version_base="1.3", config_path="conf", config_name="conf") |
|
def main(args: DictConfig) -> None: |
|
|
|
|
|
logger.info(OmegaConf.to_yaml(args)) |
|
args, training_args, model_args = global_setup(args) |
|
|
|
|
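    # Detect an existing, fully-saved checkpoint in output_dir so training can resume automatically.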
|
last_checkpoint = None |
|
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: |
|
last_checkpoint = get_last_checkpoint(training_args.output_dir) |
|
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: |
|
logger.warning( |
|
f"Output directory ({training_args.output_dir}) already exists and is not empty. " |
|
"There is no checkpoint in the directory. Or we can resume from `resume_from_checkpoint`." |
|
) |
|
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: |
|
logger.info( |
|
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " |
|
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch." |
|
) |
|
|
|
|
|
set_seed(args.training.seed) |
|
|
|
|
|
train_dataset, eval_dataset = prepare_datasets(args) |
|
|
|
|
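    # Secrets (e.g., a storage sas_key or an HF auth token) can be supplied via a local .env file.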
|
logger.info(f"Try to load sas_key from .env file: {dotenv.load_dotenv('.env')}.") |
|
use_auth_token = os.getenv("USE_AUTH_TOKEN", False) |
|
|
|
processor = prepare_processor(model_args, use_auth_token) |
|
|
|
train_dataset, eval_dataset = prepare_data_transform( |
|
training_args, model_args, train_dataset, eval_dataset, processor |
|
) |
|
|
|
    collate_fn = prepare_collate_fn(training_args, model_args, processor)

    # training_args.compute_metrics is a boolean switch; anything other than True disables metric computation.
    compute_metrics = training_args.compute_metrics
    if compute_metrics is not True:
        compute_metrics = None
|
model = prepare_model(model_args, use_auth_token) |
|
if hasattr(model, "language_model") and model.language_model.config.bos_token_id is None: |
|
model.language_model.config.bos_token_id = processor.tokenizer.bos_token_id |
|
logger.warning(f"Set bos_token_id in language_model to {processor.tokenizer.bos_token_id}") |
|
if hasattr(model, "language_model") and model.language_model.config.eos_token_id is None: |
|
model.language_model.config.eos_token_id = processor.tokenizer.eos_token_id |
|
logger.warning(f"Set eos_token_id in language_model to {processor.tokenizer.eos_token_id}") |
|
|
|
prepare_model_trainable_parameters(model, args) |
|
|
|
|
|
custom_callbacks = [LoggerCallback(), EvalLossCallback()] |
|
if args.wandb.log is True: |
|
custom_callbacks.append(CustomWandbCallBack(args)) |
|
if training_args.evaluate_before_train: |
|
custom_callbacks.append(EvaluateFirstStepCallback()) |
|
|
|
trainer = SCASeq2SeqTrainer( |
|
model=model, |
|
args=training_args, |
|
train_dataset=train_dataset if training_args.do_train else None, |
|
eval_dataset=eval_dataset if training_args.do_eval or training_args.do_train else None, |
|
compute_metrics=compute_metrics, |
|
data_collator=collate_fn, |
|
tokenizer=processor.tokenizer, |
|
callbacks=custom_callbacks, |
|
) |
|
|
|
|
|
if training_args.do_train: |
|
checkpoint = None |
|
if training_args.resume_from_checkpoint is not None: |
|
checkpoint = training_args.resume_from_checkpoint |
|
elif last_checkpoint is not None: |
|
checkpoint = last_checkpoint |
|
train_result = trainer.train(resume_from_checkpoint=checkpoint) |
|
|
|
trainer.log_metrics("train", train_result.metrics) |
|
trainer.save_metrics("train", train_result.metrics) |
|
trainer.save_state() |
|
|
|
|
|
if training_args.do_eval: |
|
for eval_dataset_k, eval_dataset_v in eval_dataset.items(): |
|
metrics = trainer.evaluate(eval_dataset_v, metric_key_prefix=eval_dataset_k) |
|
trainer.log_metrics("eval", metrics) |
|
trainer.save_metrics("eval", metrics) |
|
|
|
if training_args.do_inference: |
|
for eval_dataset_k, eval_dataset_v in eval_dataset.items(): |
|
trainer.inference(eval_dataset_v, metric_key_prefix=eval_dataset_k) |
|
|
|
|
|
def prepare_collate_fn(training_args, model_args, processor): |
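    """Pick the data collator matching the model family and bind the tokenizer to it."""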
|
DataCollatorClass = None |
|
if isinstance(model_args, SAMCaptionerModelArguments): |
|
DataCollatorClass = SamCaptionerDataCollator |
|
elif isinstance(model_args, SCAModelBaseArguments): |
|
DataCollatorClass = SCADataCollator |
|
    else:
        raise ValueError(
            f"model_args must be one of [SAMCaptionerModelArguments, SCAModelBaseArguments], got {type(model_args)}"
        )
    collate_fn = DataCollatorClass(processor.tokenizer)
|
return collate_fn |
|
|
|
|
|
def prepare_data_transform(training_args, model_args, train_dataset, eval_dataset, processor): |
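    """Attach train/eval transforms to the datasets.

    Map-style `Dataset`s get lazy, on-the-fly transforms via `with_transform`;
    streaming `IterableDataset`s are transformed with batched `map` using the
    corresponding per-device batch size.
    """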
|
DataTransformClass = None |
|
if isinstance(model_args, SAMCaptionerModelArguments): |
|
DataTransformClass = SamCaptionerDataTransform |
|
elif isinstance(model_args, SCAModelBaseArguments): |
|
        DataTransformClass = SCADataTransform
    else:
        raise ValueError(
            f"model_args must be one of [SAMCaptionerModelArguments, SCAModelBaseArguments], got {type(model_args)}"
        )
|
if training_args.do_train: |
|
if train_dataset is None: |
|
raise ValueError("train_dataset must be provided if do_train is True") |
|
|
|
num_masks_per_sample = training_args.num_masks_per_sample |
|
if num_masks_per_sample is None: |
|
num_masks_per_sample = 64 |
|
logger.info(f"num_masks_per_sample not provided, defaulting to {num_masks_per_sample}") |
|
|
|
data_transforms = training_args.data_transforms |
|
|
|
train_transforms = DataTransformClass( |
|
processor.sam_processor, processor.tokenizer, "train", num_masks_per_sample, data_transforms |
|
) |
|
|
|
if isinstance(train_dataset, Dataset) and training_args.max_train_samples is not None: |
|
train_dataset = train_dataset.shuffle(seed=training_args.seed).select( |
|
range(training_args.max_train_samples) |
|
) |
|
|
|
if isinstance(train_dataset, Dataset): |
|
train_dataset = train_dataset.with_transform(train_transforms) |
|
elif isinstance(train_dataset, IterableDataset): |
|
train_dataset = train_dataset.map( |
|
train_transforms, batched=True, batch_size=training_args.per_device_train_batch_size |
|
) |
|
else: |
|
raise ValueError(f"dataset must be one of [Dataset, IterableDataset], got {type(train_dataset)}") |
|
else: |
|
logger.warning("do_train is False, so we do not apply data augmentation to train_dataset") |
|
|
|
if training_args.do_eval or training_args.do_inference or training_args.do_train: |
|
if eval_dataset is None: |
|
raise ValueError("eval_dataset must be provided if do_eval or do_inference is True") |
|
|
|
eval_transforms = DataTransformClass(processor.sam_processor, processor.tokenizer, "inference") |
|
for eval_dataset_k, eval_dataset_v in eval_dataset.items(): |
|
if isinstance(eval_dataset_v, Dataset) and training_args.max_eval_samples is not None: |
|
eval_dataset_v = eval_dataset_v.select(range(training_args.max_eval_samples)) |
|
|
|
if isinstance(eval_dataset_v, Dataset): |
|
eval_dataset_v = eval_dataset_v.with_transform(eval_transforms) |
|
elif isinstance(eval_dataset_v, IterableDataset): |
|
eval_dataset_v = eval_dataset_v.map( |
|
eval_transforms, batched=True, batch_size=training_args.per_device_eval_batch_size |
|
) |
|
else: |
|
raise ValueError(f"dataset must be one of [Dataset, IterableDataset], got {type(eval_dataset_v)}") |
|
eval_dataset[eval_dataset_k] = eval_dataset_v |
|
else: |
|
logger.warning( |
|
"do_eval and do_inference and do_train are False, so we do not apply data augmentation to eval_dataset" |
|
) |
|
return train_dataset, eval_dataset |
|
|
|
|
|
def prepare_model_trainable_parameters(model, args): |
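    """Freeze all parameters, then re-enable gradients only for the modules named in `args.training.trainable_params`."""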
|
trainable_params = args.training.trainable_params |
|
if trainable_params is None: |
|
logger.info("trainable_params is not provided, defaulting to the `config_parameters` method of the model.") |
|
return |
|
|
|
for param in model.parameters(): |
|
param.requires_grad = False |
|
|
|
logger.info(f"Config trainable_params: {trainable_params}") |
|
for param_name in trainable_params: |
|
param = get_parameter_by_name(model, param_name) |
|
for _param in param.parameters(): |
|
_param.requires_grad = True |
|
logger.info(f"Set {param_name} to trainable") |
|
|
|
|
|
def prepare_model(model_args, use_auth_token=False): |
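    """Instantiate the model selected by `model_args`.

    Component checkpoints (SAM encoder + language model head) are assembled with
    the `from_sam*_pretrained` constructors; a saved SCA checkpoint is reloaded
    via the architecture class recorded in its config.json.
    """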
|
if isinstance(model_args, SAMCaptionerModelArguments): |
|
model_args: SAMCaptionerModelArguments |
|
model = SAMCaptionerModel.from_sam_captioner_pretrained( |
|
model_args.sam_model_name_or_path, |
|
model_args.captioner_model_name_or_path, |
|
cache_dir=model_args.cache_dir, |
|
use_auth_token=use_auth_token, |
|
trust_remote_code=True, |
|
dtype=model_args.dtype, |
|
use_vcot=model_args.use_vcot, |
|
) |
|
elif isinstance(model_args, SCAModelBaseArguments): |
|
model_args: SCAModelBaseArguments |
|
if model_args.model_name_or_path is None: |
|
if model_args.sam_model_name_or_path is None: |
|
raise ValueError( |
|
"model_args.sam_model_name_or_path must be specified in SCAModelBaseArguments if model_args.model_name_or_path is None. " |
|
"Since we are not loading from a existing sca model." |
|
) |
|
if model_args.lm_head_model_name_or_path is None: |
|
raise ValueError( |
|
"model_args.lm_head_model_name_or_path must be specified in SCAModelBaseArguments if model_args.model_name_or_path is None. " |
|
"Since we are not loading from a existing sca model." |
|
) |
|
|
|
if isinstance(model_args, SCAModelArguments): |
|
model = ScaModel.from_sam_text_pretrained( |
|
model_args.sam_model_name_or_path, |
|
model_args.lm_head_model_name_or_path, |
|
model_args.additional_num_hidden_layers, |
|
model_args.num_caption_tokens, |
|
cache_dir=model_args.cache_dir, |
|
use_auth_token=use_auth_token, |
|
trust_remote_code=True, |
|
) |
|
elif isinstance(model_args, SCADirectDecodingModelArguments): |
|
model = ScaDirectDecodingModel.from_sam_text_pretrained( |
|
model_args.sam_model_name_or_path, |
|
model_args.lm_head_model_name_or_path, |
|
model_args.additional_num_hidden_layers, |
|
cache_dir=model_args.cache_dir, |
|
use_auth_token=use_auth_token, |
|
trust_remote_code=True, |
|
) |
|
elif isinstance(model_args, SCAMultitaskModelArguments): |
|
model = ScaMultitaskModel.from_sam_text_pretrained( |
|
model_args.sam_model_name_or_path, |
|
model_args.lm_head_model_name_or_path, |
|
model_args.additional_num_hidden_layers, |
|
model_args.num_caption_tokens, |
|
model_args.num_task_tokens, |
|
cache_dir=model_args.cache_dir, |
|
use_auth_token=use_auth_token, |
|
trust_remote_code=True, |
|
) |
|
elif isinstance(model_args, ScaMultitaskV2ModelArguments): |
|
model = ScaMultitaskV2Model.from_sam_text_pretrained( |
|
model_args.sam_model_name_or_path, |
|
model_args.lm_head_model_name_or_path, |
|
model_args.additional_num_hidden_layers, |
|
model_args.num_caption_tokens, |
|
model_args.num_task_tokens, |
|
model_args.num_caption_heads, |
|
cache_dir=model_args.cache_dir, |
|
use_auth_token=use_auth_token, |
|
trust_remote_code=True, |
|
) |
|
elif isinstance(model_args, SCAMultitaskSplitMixerModelArguments): |
|
model = ScaMultitaskSplitMixerModel.from_sam_text_pretrained( |
|
model_args.sam_model_name_or_path, |
|
model_args.lm_head_model_name_or_path, |
|
model_args.additional_num_hidden_layers, |
|
model_args.num_caption_tokens, |
|
model_args.num_task_tokens, |
|
model_args.num_caption_heads, |
|
cache_dir=model_args.cache_dir, |
|
use_auth_token=use_auth_token, |
|
trust_remote_code=True, |
|
) |
|
elif isinstance(model_args, SCADirectDecodingV2ModelArguments): |
|
model = ScaDirectDecodingV2Model.from_sam_text_pretrained( |
|
model_args.sam_model_name_or_path, |
|
model_args.lm_head_model_name_or_path, |
|
model_args.additional_num_hidden_layers, |
|
model_args.num_task_tokens, |
|
cache_dir=model_args.cache_dir, |
|
use_auth_token=use_auth_token, |
|
trust_remote_code=True, |
|
) |
|
elif isinstance(model_args, SCAMultitaskROIPoolModelArguments): |
|
model = ScaMultitaskROIPoolModel.from_sam_text_pretrained( |
|
model_args.sam_model_name_or_path, |
|
model_args.lm_head_model_name_or_path, |
|
num_task_tokens=model_args.num_task_tokens, |
|
vl_projector_type=model_args.vl_projector_type, |
|
vl_projector_norm_type=model_args.vl_projector_norm_type, |
|
cache_dir=model_args.cache_dir, |
|
use_auth_token=use_auth_token, |
|
trust_remote_code=True, |
|
) |
|
elif isinstance(model_args, ScaTimmMultitaskV2ModelArguments): |
|
timm_vision_name = getattr(model_args, "timm_vision_name", None) |
|
logger.info(f"timm_vision_name: {timm_vision_name}; sam path: {model_args.sam_model_name_or_path}") |
|
model = ScaTimmMultitaskV2Model.from_sam_timm_text_pretrained( |
|
timm_vision_name, |
|
model_args.sam_model_name_or_path, |
|
model_args.lm_head_model_name_or_path, |
|
model_args.additional_num_hidden_layers, |
|
model_args.num_caption_tokens, |
|
model_args.num_task_tokens, |
|
model_args.num_caption_heads, |
|
cache_dir=model_args.cache_dir, |
|
use_auth_token=use_auth_token, |
|
trust_remote_code=True, |
|
) |
|
else: |
|
            raise ValueError(
                "model_args must be one of [SCAModelArguments, SCADirectDecodingModelArguments, "
                "SCAMultitaskModelArguments, ScaMultitaskV2ModelArguments, SCAMultitaskSplitMixerModelArguments, "
                "SCADirectDecodingV2ModelArguments, SCAMultitaskROIPoolModelArguments, "
                f"ScaTimmMultitaskV2ModelArguments], got {type(model_args)}"
            )
|
logger.info( |
|
f"Initalized sca model from sam model {model_args.sam_model_name_or_path} and lm head model {model_args.lm_head_model_name_or_path}" |
|
) |
|
else: |
|
|
|
model_config_json_path = os.path.join(model_args.model_name_or_path, "config.json") |
|
with open(model_config_json_path, "r") as f: |
|
model_config = json.load(f) |
|
if len(model_config["architectures"]) > 1: |
|
raise ValueError(f"Only support one architecture in model_config, got {model_config['architectures']}") |
|
architecture = model_config["architectures"][0] |
|
|
|
architecture_class = getattr(src.models.sca, architecture) |
|
model = architecture_class.from_pretrained(model_args.model_name_or_path) |
|
logger.info(f"Loaded sca model from {model_args.model_name_or_path}") |
|
else: |
|
raise ValueError( |
|
f"model_args must be one of [SAMCaptionerModelArguments, SCAModelBaseArguments], got {model_args}" |
|
) |
|
|
|
if ( |
|
hasattr(model_args, "lm_head_model_name_or_path") |
|
and model_args.lm_head_model_name_or_path == "microsoft/phi-2" |
|
): |
|
|
|
|
|
logger.warning("phi-2 cannot take in input_embeds, so we need to add it.") |
|
|
|
import types |
|
|
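        # A minimal replacement forward, assuming the phi-2 remote-code layout
        # (`self.embd` embedding layer, `self.h` decoder blocks): it adds
        # inputs_embeds support while keeping the original input_ids path.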
|
def phi_forward_updated( |
|
self, |
|
input_ids=None, |
|
inputs_embeds=None, |
|
past_key_values=None, |
|
attention_mask=None, |
|
): |
|
if input_ids is not None and inputs_embeds is not None: |
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") |
|
if input_ids is not None: |
|
hidden_states = self.embd(input_ids) |
|
elif inputs_embeds is not None: |
|
hidden_states = inputs_embeds |
|
else: |
|
raise ValueError("You have to specify either input_ids or inputs_embeds") |
|
for layer in self.h: |
|
hidden_states = layer( |
|
hidden_states, |
|
past_key_values=past_key_values, |
|
attention_mask=attention_mask, |
|
) |
|
|
|
return hidden_states |
|
|
|
|
|
|
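        # Bind the patched functions to this model instance only, leaving the class untouched.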
|
model.language_model.transformer.forward = types.MethodType( |
|
phi_forward_updated, model.language_model.transformer |
|
) |
|
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
def phi_forinput_ids_causal_lm_forward_updated( |
|
self, |
|
input_ids=None, |
|
inputs_embeds=None, |
|
past_key_values=None, |
|
attention_mask=None, |
|
labels=None, |
|
**kwargs, |
|
): |
|
hidden_states = self.transformer( |
|
input_ids, inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask |
|
) |
|
lm_logits = self.lm_head(hidden_states) |
|
|
|
loss = None |
|
if labels is not None: |
|
loss = self.loss(lm_logits, labels) |
|
|
|
return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values) |
|
|
|
|
|
|
|
model.language_model.forward = types.MethodType( |
|
phi_forinput_ids_causal_lm_forward_updated, model.language_model |
|
) |
|
|
|
import torch |
|
from dataclasses import dataclass, field |
|
from typing import Any, Dict |
|
|
|
@dataclass |
|
class InferenceParams: |
|
"""Inference parameters passed to model to efficiently calculate |
|
and store context during inference. |
|
Reference: |
|
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py. |
|
Args: |
|
max_seqlen: Maximum sequence length. |
|
max_batch_size: Maximum batch size. |
|
seqlen_offset: Sequence length offset. |
|
batch_size_offset: Batch size offset. |
|
key_value_memory_dict: Key value memory dictionary. |
|
lengths_per_sample: Lengths per sample. |
|
""" |
|
|
|
max_seqlen: int = field(metadata={"help": "Maximum sequence length."}) |
|
|
|
max_batch_size: int = field(metadata={"help": "Maximum batch size."}) |
|
|
|
seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."}) |
|
|
|
batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."}) |
|
|
|
key_value_memory_dict: Dict[str, Any] = field( |
|
default_factory=dict, metadata={"help": "Key value memory dictionary."} |
|
) |
|
|
|
lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."}) |
|
|
|
def phi_prepare_inputs_for_generation( |
|
self, |
|
input_ids=None, |
|
inputs_embeds=None, |
|
past_key_values=None, |
|
attention_mask=None, |
|
**kwargs, |
|
): |
|
model_inputs = {} |
|
|
|
|
|
|
|
|
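            # First generation step: feed the prompt as embeddings; later steps
            # (when past_key_values is already an InferenceParams) decode from the
            # last generated token id only.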
|
if inputs_embeds is not None and past_key_values is None: |
|
model_inputs["inputs_embeds"] = inputs_embeds |
|
|
|
            if past_key_values is None or not isinstance(past_key_values, InferenceParams):
|
past_key_values = InferenceParams( |
|
max_seqlen=self.config.n_positions, |
|
max_batch_size=input_ids.shape[0], |
|
seqlen_offset=0, |
|
batch_size_offset=0, |
|
key_value_memory_dict={}, |
|
lengths_per_sample=None, |
|
) |
|
else: |
|
|
|
|
|
past_key_values.seqlen_offset = attention_mask.shape[1] - 1 |
|
input_ids = input_ids[:, -1].unsqueeze(-1) |
|
|
|
if "inputs_embeds" not in model_inputs: |
|
model_inputs["input_ids"] = input_ids |
|
|
|
model_inputs.update( |
|
{ |
|
"past_key_values": past_key_values, |
|
"attention_mask": attention_mask, |
|
} |
|
) |
|
|
|
return model_inputs |
|
|
|
|
|
|
|
model.language_model.prepare_inputs_for_generation = types.MethodType( |
|
phi_prepare_inputs_for_generation, model.language_model |
|
) |
|
|
|
logger.info(f"Model: {model.config}") |
|
return model |
|
|
|
|
|
def prepare_datasets(args): |
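    """Compose the Hydra data configs and instantiate the train/eval datasets.

    Returns a `(train_dataset, eval_dataset)` pair: `train_dataset` is a single,
    possibly interleaved dataset (or None when no train data is configured);
    `eval_dataset` maps a name built by `_get_data_name` to each instantiated
    eval dataset.
    """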
|
train_data = [] |
|
for train_data_config_name in args.train_data: |
|
cfg = hydra.compose(config_name=f"data/{train_data_config_name}", overrides=args.train_data_overrides) |
|
train_data.append(cfg.data) |
|
args.train_data = train_data |
|
|
|
|
|
if len(args.eval_data) > 1: |
|
logger.warning(f"We should only inference one dataset, got {args.eval_data}") |
|
eval_data = [] |
|
for eval_data_config_name in args.eval_data: |
|
cfg = hydra.compose(config_name=f"data/{eval_data_config_name}", overrides=args.eval_data_overrides) |
|
eval_data.append(cfg.data) |
|
|
|
train_dataset = [] |
|
for i, each_train_data in enumerate(train_data): |
|
|
|
each_train_data.split = "train" |
|
|
|
_train_dataset = instantiate(each_train_data) |
|
train_dataset.append(_train_dataset) |
|
logger.info(f"Train Dataset [{i}]: {each_train_data}\n{_train_dataset}") |
|
|
|
eval_dataset = {} |
|
for i, each_eval_data in enumerate(eval_data): |
|
|
|
|
|
if "visual_genome.py" in each_eval_data.path and getattr(each_eval_data, "use_densecap_splits", None) is True: |
|
logger.info("Using densecap splits in Visual Genome, using test split to eval") |
|
each_eval_data.split = "test" |
|
|
|
|
|
elif "refcoco.py" in each_eval_data.path: |
|
if each_eval_data.name.startswith("refcoco-") or each_eval_data.name.startswith("refcoco+-"): |
|
if each_eval_data.split is None or each_eval_data.split == "train": |
|
raise ValueError(f"refcoco{{,+}} must have split for eval. got {each_eval_data.split}") |
|
logger.info(f"Using refcoco{{,+}}: {each_eval_data.split} split to eval") |
|
elif each_eval_data.name.startswith("refcocog"): |
|
logger.info("Using refcocog val split to eval") |
|
each_eval_data.split = "validation" |
|
elif each_eval_data.name.startswith("refclef"): |
|
logger.info("Using refclef val split to eval") |
|
each_eval_data.split = "validation" |
|
|
|
|
|
elif "coco_instance.py" in each_eval_data.path or "coco_instance-local.py" in each_eval_data.path: |
|
logger.info("Using coco val split to eval") |
|
each_eval_data.split = "validation" |
|
|
|
elif "objects365-local.py" in each_eval_data.path: |
|
logger.info("Using objects365 (in fact, it is COCO) val split to eval") |
|
each_eval_data.split = "validation" |
|
|
|
elif "v3det-local.py" in each_eval_data.path: |
|
logger.info("Using v3det (in fact, it is COCO) val split to eval") |
|
each_eval_data.split = "validation" |
|
|
|
elif "sbu-pseudo_region-local.py" in each_eval_data.path or "sbu-pseudo_region.py" in each_eval_data.path: |
|
logger.info("Using sbu to eval, but it does not have test split, so we use train split") |
|
each_eval_data.split = "train" |
|
|
|
elif "coco_caption-pseudo_region.py" in each_eval_data.path: |
|
logger.info("Using coco_caption (in fact, it is COCO) val split to eval") |
|
each_eval_data.split = "validation" |
|
|
|
elif ( |
|
"visual_genome-densecap-local.py" in each_eval_data.path |
|
or "visual_genome-grit-local.py" in each_eval_data.path |
|
): |
|
logger.info(f"Using visual_genome (They are my custom splits for GRiT and Densecap) test split to eval") |
|
each_eval_data.split = "test" |
|
elif "m3d_2d.py" in each_eval_data.path: |
|
logger.info(f"Using m3d_2d val split to eval") |
|
each_eval_data.split = "validation" |
|
else: |
|
raise ValueError( |
|
f"Unknown dataset {each_eval_data.path}, we cannot determine the split for it. Please edit `src/train.py:prepare_datasets` to add the split for it." |
|
) |
|
print(f"each_eval_data{each_eval_data}") |
|
_eval_dataset = instantiate(each_eval_data) |
|
eval_dataset_name = _get_data_name(each_eval_data) |
|
eval_dataset[eval_dataset_name] = _eval_dataset |
|
logger.info(f"Eval Dataset [{i}]: {each_eval_data}\n{_eval_dataset}") |
|
args.eval_data = eval_data |
|
|
|
if args.train_data_interleave_probabilities is not None and len(train_dataset) != len( |
|
args.train_data_interleave_probabilities |
|
): |
|
raise ValueError( |
|
f"train_data_interleave_probabilities must have the same length as train_data, got {len(train_dataset)} and {len(args.train_data_interleave_probabilities)}" |
|
) |
|
|
|
if len(train_dataset) > 0: |
|
if args.train_data_interleave_probabilities is None: |
|
logger.warning( |
|
"train_data_interleave_probabilities is not provided, " |
|
"the resulting dataset will have max_length_datasets*nb_dataset samples. " |
|
"As we use `all_exhausted` stopping strategy which is a oversampling strategy." |
|
) |
|
else: |
|
if sum(args.train_data_interleave_probabilities) != 1.0: |
|
logger.info(f"Normalize train_data_interleave_probabilities to sum to 1.0") |
|
args.train_data_interleave_probabilities = [ |
|
each_prob / sum(args.train_data_interleave_probabilities) |
|
for each_prob in args.train_data_interleave_probabilities |
|
] |
|
logger.info(f"train_data_interleave_probabilities: {args.train_data_interleave_probabilities}") |
|
|
|
|
|
train_dataset = interleave_datasets( |
|
train_dataset, |
|
probabilities=args.train_data_interleave_probabilities, |
|
seed=args.training.seed, |
|
stopping_strategy="all_exhausted", |
|
) |
|
else: |
|
train_dataset = None |
|
|
|
logger.info(f"Train Dataset: {train_dataset}") |
|
logger.info(f"Eval Dataset: {eval_dataset}") |
|
return train_dataset, eval_dataset |
|
|
|
|
|
def _get_data_name(dataset_config_dict): |
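    """Build a `{path_stem}-{name}-{split}` key identifying a dataset config."""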
|
|
|
path = dataset_config_dict.path |
|
path_name = os.path.splitext(os.path.basename(path))[0] |
|
name = dataset_config_dict.name |
|
split = dataset_config_dict.split |
|
return f"{path_name}-{name}-{split}" |
|
|
|
|
|
def prepare_processor(model_args, use_auth_token): |
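    """Build the processor (SAM image processor + tokenizer) matching the model family."""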
|
if isinstance(model_args, SAMCaptionerModelArguments): |
|
processor = SAMCaptionerProcessor.from_sam_captioner_pretrained( |
|
model_args.sam_model_name_or_path, |
|
model_args.captioner_model_name_or_path, |
|
cache_dir=model_args.cache_dir, |
|
model_max_length=model_args.model_max_length, |
|
use_auth_token=use_auth_token, |
|
trust_remote_code=True, |
|
) |
|
|
|
|
|
elif isinstance(model_args, SCAModelBaseArguments): |
|
processor = ScaProcessor.from_sam_text_pretrained( |
|
model_args.sam_model_name_or_path, |
|
model_args.lm_head_model_name_or_path, |
|
cache_dir=model_args.cache_dir, |
|
model_max_length=model_args.model_max_length, |
|
use_auth_token=use_auth_token, |
|
trust_remote_code=True, |
|
) |
|
else: |
|
raise ValueError( |
|
f"model_args must be one of [SAMCaptionerModelArguments, SCAModelBaseArguments], got {type(model_args)}" |
|
) |
|
|
|
if processor.tokenizer.pad_token is None: |
|
if processor.tokenizer.eos_token is None: |
|
raise ValueError("tokenizer must have either eos_token") |
|
processor.tokenizer.pad_token = processor.tokenizer.eos_token |
|
|
|
return processor |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|