import logging
import os
import hydra
from hydra.utils import instantiate
from datasets import Dataset, load_dataset, IterableDataset, concatenate_datasets, interleave_datasets
from omegaconf import DictConfig, OmegaConf
from src.data.transforms import SamCaptionerDataTransform, SCADataTransform
from src.data.collator import SamCaptionerDataCollator, SCADataCollator
from src.arguments import (
Arguments,
global_setup,
SAMCaptionerModelArguments,
SCAModelBaseArguments,
SCAModelArguments,
SCADirectDecodingModelArguments,
SCAMultitaskModelArguments,
SCAMultitaskSplitMixerModelArguments,
ScaMultitaskV2ModelArguments,
VGDenseCapDataArgument,
RefCOCODataArgument,
SA1BCapDataArgument,
COCOInstanceDataArgument,
SCADirectDecodingV2ModelArguments,
SCAMultitaskROIPoolModelArguments,
ScaTimmMultitaskV2ModelArguments,
)
from src.models.sam_captioner import SAMCaptionerConfig, SAMCaptionerModel, SAMCaptionerProcessor
from src.sca_seq2seq_trainer import SCASeq2SeqTrainer, get_parameter_by_name, SAVING_FINISHED_FLAG
from src.models.sca import (
ScaModel,
ScaConfig,
ScaProcessor,
ScaDirectDecodingModel,
ScaMultitaskModel,
ScaMultitaskSplitMixerModel,
ScaMultitaskV2Model,
ScaDirectDecodingV2Model,
ScaMultitaskROIPoolModel,
ScaTimmMultitaskV2Model,
)
from src.integrations import CustomWandbCallBack, EvaluateFirstStepCallback, LoggerCallback, EvalLossCallback
import src.models.sca
import src.utils
from transformers.trainer_utils import _re_checkpoint
from transformers import set_seed
import json
import dotenv
logger = logging.getLogger(__name__)
# Adapted from `transformers/trainer_utils.py`, with an extra check for the saving-finished flag.
def get_last_checkpoint(folder):
content = os.listdir(folder)
checkpoints = [
path
for path in content
if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
]
if len(checkpoints) == 0:
logger.warning(f"No checkpoint found in {folder}, but we got: {content}")
return
checkpoints = sorted(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0]), reverse=True)
    for checkpoint in checkpoints:
        # NOTE: a partially saved checkpoint may be unreadable, so only accept
        # checkpoints that carry the saving-finished flag file.
        if os.path.isfile(os.path.join(folder, checkpoint, SAVING_FINISHED_FLAG)):
            return os.path.join(folder, checkpoint)
        else:
            logger.warning(f"Checkpoint {os.path.join(folder, checkpoint)} does not have {SAVING_FINISHED_FLAG}, skip")
return
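# Example (hypothetical layout): given output_dir/ containing
#   checkpoint-500/                        <- missing SAVING_FINISHED_FLAG, skipped
#   checkpoint-1000/SAVING_FINISHED_FLAG   <- newest finished checkpoint
# get_last_checkpoint(output_dir) returns "<output_dir>/checkpoint-1000".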
@hydra.main(version_base="1.3", config_path="conf", config_name="conf")
def main(args: DictConfig) -> None:
# NOTE(xiaoke): follow https://github.com/huggingface/transformers/blob/main/examples/pytorch/image-classification/run_image_classification.py
logger.info(OmegaConf.to_yaml(args))
args, training_args, model_args = global_setup(args)
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            logger.warning(
                f"Output directory ({training_args.output_dir}) already exists and is not empty, "
                "but contains no finished checkpoint. Alternatively, resume via `resume_from_checkpoint`."
            )
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Set seed before initializing model.
set_seed(args.training.seed)
# Initialize our dataset and prepare it
train_dataset, eval_dataset = prepare_datasets(args)
# NOTE(xiaoke): load sas_key from .env for huggingface model downloading.
logger.info(f"Try to load sas_key from .env file: {dotenv.load_dotenv('.env')}.")
use_auth_token = os.getenv("USE_AUTH_TOKEN", False)
processor = prepare_processor(model_args, use_auth_token)
train_dataset, eval_dataset = prepare_data_transform(
training_args, model_args, train_dataset, eval_dataset, processor
)
collate_fn = prepare_collate_fn(training_args, model_args, processor)
# Load the accuracy metric from the datasets package
# metric = evaluate.load("accuracy")
# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.
# def compute_metrics(p):
# """Computes accuracy on a batch of predictions"""
# return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
compute_metrics = training_args.compute_metrics
if compute_metrics is not True:
# NOTE: compute_metrics = None triggers the default `prediction_loss_only=True`
# NOTE: compute_metrics should be a function, but we define the function in the trainer, so we use bool here to indicate the usage.
compute_metrics = None
# config = AutoConfig.from_pretrained(
# model_args.config_name or model_args.model_name_or_path,
# num_labels=len(labels),
# label2id=label2id,
# id2label=id2label,
# finetuning_task="image-classification",
# cache_dir=model_args.cache_dir,
# revision=model_args.model_revision,
# use_auth_token=True if model_args.use_auth_token else None,
# )
# model = AutoModelForImageClassification.from_pretrained(
# model_args.model_name_or_path,
# from_tf=bool(".ckpt" in model_args.model_name_or_path),
# config=config,
# cache_dir=model_args.cache_dir,
# revision=model_args.model_revision,
# use_auth_token=True if model_args.use_auth_token else None,
# ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
# )
# image_processor = AutoImageProcessor.from_pretrained(
# model_args.image_processor_name or model_args.model_name_or_path,
# cache_dir=model_args.cache_dir,
# revision=model_args.model_revision,
# use_auth_token=True if model_args.use_auth_token else None,
# )
model = prepare_model(model_args, use_auth_token)
if hasattr(model, "language_model") and model.language_model.config.bos_token_id is None:
model.language_model.config.bos_token_id = processor.tokenizer.bos_token_id
logger.warning(f"Set bos_token_id in language_model to {processor.tokenizer.bos_token_id}")
if hasattr(model, "language_model") and model.language_model.config.eos_token_id is None:
model.language_model.config.eos_token_id = processor.tokenizer.eos_token_id
logger.warning(f"Set eos_token_id in language_model to {processor.tokenizer.eos_token_id}")
prepare_model_trainable_parameters(model, args)
    # Initialize our trainer
custom_callbacks = [LoggerCallback(), EvalLossCallback()]
if args.wandb.log is True:
custom_callbacks.append(CustomWandbCallBack(args))
if training_args.evaluate_before_train:
custom_callbacks.append(EvaluateFirstStepCallback())
trainer = SCASeq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval or training_args.do_train else None,
compute_metrics=compute_metrics,
data_collator=collate_fn,
tokenizer=processor.tokenizer,
callbacks=custom_callbacks,
)
# Training
if training_args.do_train:
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
# trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
# Evaluation
if training_args.do_eval:
for eval_dataset_k, eval_dataset_v in eval_dataset.items():
metrics = trainer.evaluate(eval_dataset_v, metric_key_prefix=eval_dataset_k)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.do_inference:
for eval_dataset_k, eval_dataset_v in eval_dataset.items():
trainer.inference(eval_dataset_v, metric_key_prefix=eval_dataset_k)
def prepare_collate_fn(training_args, model_args, processor):
    DataCollatorClass = None
    if isinstance(model_args, SAMCaptionerModelArguments):
        DataCollatorClass = SamCaptionerDataCollator
    elif isinstance(model_args, SCAModelBaseArguments):
        DataCollatorClass = SCADataCollator
    else:
        raise ValueError(
            f"model_args must be one of [SAMCaptionerModelArguments, SCAModelBaseArguments], got {type(model_args)}"
        )
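    # The Trainer's DataLoader calls `collate_fn(features)` on a list of
    # transformed samples and expects a dict of batched tensors in return
    # (the standard `transformers` data-collator contract).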
collate_fn = DataCollatorClass(processor.tokenizer)
return collate_fn
def prepare_data_transform(training_args, model_args, train_dataset, eval_dataset, processor):
    DataTransformClass = None
    if isinstance(model_args, SAMCaptionerModelArguments):
        DataTransformClass = SamCaptionerDataTransform
    elif isinstance(model_args, SCAModelBaseArguments):
        DataTransformClass = SCADataTransform
    else:
        raise ValueError(
            f"model_args must be one of [SAMCaptionerModelArguments, SCAModelBaseArguments], got {type(model_args)}"
        )
if training_args.do_train:
if train_dataset is None:
raise ValueError("train_dataset must be provided if do_train is True")
num_masks_per_sample = training_args.num_masks_per_sample
if num_masks_per_sample is None:
num_masks_per_sample = 64
logger.info(f"num_masks_per_sample not provided, defaulting to {num_masks_per_sample}")
data_transforms = training_args.data_transforms
train_transforms = DataTransformClass(
processor.sam_processor, processor.tokenizer, "train", num_masks_per_sample, data_transforms
)
if isinstance(train_dataset, Dataset) and training_args.max_train_samples is not None:
train_dataset = train_dataset.shuffle(seed=training_args.seed).select(
range(training_args.max_train_samples)
)
# Set the training transforms
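        # NOTE: map-style `Dataset` objects support the lazy `with_transform` (applied
        # on access), while `IterableDataset` has no `with_transform`, so the transform
        # is applied via `map(batched=True)` as the stream is iterated.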
if isinstance(train_dataset, Dataset):
train_dataset = train_dataset.with_transform(train_transforms)
elif isinstance(train_dataset, IterableDataset):
train_dataset = train_dataset.map(
train_transforms, batched=True, batch_size=training_args.per_device_train_batch_size
)
else:
raise ValueError(f"dataset must be one of [Dataset, IterableDataset], got {type(train_dataset)}")
else:
logger.warning("do_train is False, so we do not apply data augmentation to train_dataset")
if training_args.do_eval or training_args.do_inference or training_args.do_train:
if eval_dataset is None:
raise ValueError("eval_dataset must be provided if do_eval or do_inference is True")
eval_transforms = DataTransformClass(processor.sam_processor, processor.tokenizer, "inference")
for eval_dataset_k, eval_dataset_v in eval_dataset.items():
if isinstance(eval_dataset_v, Dataset) and training_args.max_eval_samples is not None:
eval_dataset_v = eval_dataset_v.select(range(training_args.max_eval_samples))
# Set the validation transforms
if isinstance(eval_dataset_v, Dataset):
eval_dataset_v = eval_dataset_v.with_transform(eval_transforms)
elif isinstance(eval_dataset_v, IterableDataset):
eval_dataset_v = eval_dataset_v.map(
eval_transforms, batched=True, batch_size=training_args.per_device_eval_batch_size
)
else:
raise ValueError(f"dataset must be one of [Dataset, IterableDataset], got {type(eval_dataset_v)}")
eval_dataset[eval_dataset_k] = eval_dataset_v
else:
        logger.warning(
            "do_eval, do_inference, and do_train are all False, so we do not apply data transforms to eval_dataset"
        )
return train_dataset, eval_dataset
def prepare_model_trainable_parameters(model, args):
trainable_params = args.training.trainable_params
if trainable_params is None:
logger.info("trainable_params is not provided, defaulting to the `config_parameters` method of the model.")
return
for param in model.parameters():
param.requires_grad = False
logger.info(f"Config trainable_params: {trainable_params}")
for param_name in trainable_params:
param = get_parameter_by_name(model, param_name)
for _param in param.parameters():
_param.requires_grad = True
logger.info(f"Set {param_name} to trainable")
def prepare_model(model_args, use_auth_token=False):
if isinstance(model_args, SAMCaptionerModelArguments):
model_args: SAMCaptionerModelArguments
model = SAMCaptionerModel.from_sam_captioner_pretrained(
model_args.sam_model_name_or_path,
model_args.captioner_model_name_or_path,
cache_dir=model_args.cache_dir,
use_auth_token=use_auth_token,
trust_remote_code=True,
dtype=model_args.dtype,
use_vcot=model_args.use_vcot,
)
elif isinstance(model_args, SCAModelBaseArguments):
model_args: SCAModelBaseArguments
if model_args.model_name_or_path is None:
if model_args.sam_model_name_or_path is None:
                raise ValueError(
                    "model_args.sam_model_name_or_path must be specified in SCAModelBaseArguments if model_args.model_name_or_path is None, "
                    "since we are not loading from an existing SCA model."
                )
            if model_args.lm_head_model_name_or_path is None:
                raise ValueError(
                    "model_args.lm_head_model_name_or_path must be specified in SCAModelBaseArguments if model_args.model_name_or_path is None, "
                    "since we are not loading from an existing SCA model."
                )
            # NOTE(xiaoke): Initialize different kinds of SCA models
if isinstance(model_args, SCAModelArguments):
model = ScaModel.from_sam_text_pretrained(
model_args.sam_model_name_or_path,
model_args.lm_head_model_name_or_path,
model_args.additional_num_hidden_layers,
model_args.num_caption_tokens,
cache_dir=model_args.cache_dir,
use_auth_token=use_auth_token,
trust_remote_code=True,
)
elif isinstance(model_args, SCADirectDecodingModelArguments):
model = ScaDirectDecodingModel.from_sam_text_pretrained(
model_args.sam_model_name_or_path,
model_args.lm_head_model_name_or_path,
model_args.additional_num_hidden_layers,
cache_dir=model_args.cache_dir,
use_auth_token=use_auth_token,
trust_remote_code=True,
)
elif isinstance(model_args, SCAMultitaskModelArguments):
model = ScaMultitaskModel.from_sam_text_pretrained(
model_args.sam_model_name_or_path,
model_args.lm_head_model_name_or_path,
model_args.additional_num_hidden_layers,
model_args.num_caption_tokens,
model_args.num_task_tokens,
cache_dir=model_args.cache_dir,
use_auth_token=use_auth_token,
trust_remote_code=True,
)
elif isinstance(model_args, ScaMultitaskV2ModelArguments):
model = ScaMultitaskV2Model.from_sam_text_pretrained(
model_args.sam_model_name_or_path,
model_args.lm_head_model_name_or_path,
model_args.additional_num_hidden_layers,
model_args.num_caption_tokens,
model_args.num_task_tokens,
model_args.num_caption_heads,
cache_dir=model_args.cache_dir,
use_auth_token=use_auth_token,
trust_remote_code=True,
)
elif isinstance(model_args, SCAMultitaskSplitMixerModelArguments):
model = ScaMultitaskSplitMixerModel.from_sam_text_pretrained(
model_args.sam_model_name_or_path,
model_args.lm_head_model_name_or_path,
model_args.additional_num_hidden_layers,
model_args.num_caption_tokens,
model_args.num_task_tokens,
model_args.num_caption_heads,
cache_dir=model_args.cache_dir,
use_auth_token=use_auth_token,
trust_remote_code=True,
)
elif isinstance(model_args, SCADirectDecodingV2ModelArguments):
model = ScaDirectDecodingV2Model.from_sam_text_pretrained(
model_args.sam_model_name_or_path,
model_args.lm_head_model_name_or_path,
model_args.additional_num_hidden_layers,
model_args.num_task_tokens,
cache_dir=model_args.cache_dir,
use_auth_token=use_auth_token,
trust_remote_code=True,
)
elif isinstance(model_args, SCAMultitaskROIPoolModelArguments):
model = ScaMultitaskROIPoolModel.from_sam_text_pretrained(
model_args.sam_model_name_or_path,
model_args.lm_head_model_name_or_path,
num_task_tokens=model_args.num_task_tokens,
vl_projector_type=model_args.vl_projector_type,
vl_projector_norm_type=model_args.vl_projector_norm_type,
cache_dir=model_args.cache_dir,
use_auth_token=use_auth_token,
trust_remote_code=True,
)
elif isinstance(model_args, ScaTimmMultitaskV2ModelArguments):
timm_vision_name = getattr(model_args, "timm_vision_name", None)
logger.info(f"timm_vision_name: {timm_vision_name}; sam path: {model_args.sam_model_name_or_path}")
model = ScaTimmMultitaskV2Model.from_sam_timm_text_pretrained(
timm_vision_name,
model_args.sam_model_name_or_path,
model_args.lm_head_model_name_or_path,
model_args.additional_num_hidden_layers,
model_args.num_caption_tokens,
model_args.num_task_tokens,
model_args.num_caption_heads,
cache_dir=model_args.cache_dir,
use_auth_token=use_auth_token,
trust_remote_code=True,
)
            else:
                raise ValueError(
                    f"model_args must be one of the supported SCA model argument classes, got {type(model_args)}"
                )
            logger.info(
                f"Initialized SCA model from SAM model {model_args.sam_model_name_or_path} and LM head model {model_args.lm_head_model_name_or_path}"
            )
else:
            # NOTE(xiaoke): load an existing SCA-series model for inference
model_config_json_path = os.path.join(model_args.model_name_or_path, "config.json")
with open(model_config_json_path, "r") as f:
model_config = json.load(f)
if len(model_config["architectures"]) > 1:
raise ValueError(f"Only support one architecture in model_config, got {model_config['architectures']}")
architecture = model_config["architectures"][0]
architecture_class = getattr(src.models.sca, architecture)
model = architecture_class.from_pretrained(model_args.model_name_or_path)
logger.info(f"Loaded sca model from {model_args.model_name_or_path}")
    else:
        raise ValueError(
            f"model_args must be one of [SAMCaptionerModelArguments, SCAModelBaseArguments], got {type(model_args)}"
        )
if (
hasattr(model_args, "lm_head_model_name_or_path")
and model_args.lm_head_model_name_or_path == "microsoft/phi-2"
):
        # NOTE: phi-2's forward does not accept `inputs_embeds`, so we patch it in.
        # https://huggingface.co/microsoft/phi-2/blob/main/modeling_phi.py
        logger.warning("phi-2 does not accept `inputs_embeds`; patching its forward methods.")
import types
def phi_forward_updated(
self,
input_ids=None,
inputs_embeds=None,
past_key_values=None,
attention_mask=None,
):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
if input_ids is not None:
hidden_states = self.embd(input_ids)
elif inputs_embeds is not None:
hidden_states = inputs_embeds
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
for layer in self.h:
hidden_states = layer(
hidden_states,
past_key_values=past_key_values,
attention_mask=attention_mask,
)
return hidden_states
        # NOTE: replace the method on the fly.
        # https://stackoverflow.com/questions/52292599/can-i-replace-an-existing-method-of-an-object-in-python
model.language_model.transformer.forward = types.MethodType(
phi_forward_updated, model.language_model.transformer
)
from transformers.modeling_outputs import CausalLMOutputWithPast
def phi_forinput_ids_causal_lm_forward_updated(
self,
input_ids=None,
inputs_embeds=None,
past_key_values=None,
attention_mask=None,
labels=None,
**kwargs,
):
hidden_states = self.transformer(
input_ids, inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask
)
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
loss = self.loss(lm_logits, labels)
return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
        # NOTE: replace the method on the fly.
        # https://stackoverflow.com/questions/52292599/can-i-replace-an-existing-method-of-an-object-in-python
model.language_model.forward = types.MethodType(
phi_forinput_ids_causal_lm_forward_updated, model.language_model
)
import torch
from dataclasses import dataclass, field
from typing import Any, Dict
@dataclass
class InferenceParams:
"""Inference parameters passed to model to efficiently calculate
and store context during inference.
Reference:
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
Args:
max_seqlen: Maximum sequence length.
max_batch_size: Maximum batch size.
seqlen_offset: Sequence length offset.
batch_size_offset: Batch size offset.
key_value_memory_dict: Key value memory dictionary.
lengths_per_sample: Lengths per sample.
"""
max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
max_batch_size: int = field(metadata={"help": "Maximum batch size."})
seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
key_value_memory_dict: Dict[str, Any] = field(
default_factory=dict, metadata={"help": "Key value memory dictionary."}
)
lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
def phi_prepare_inputs_for_generation(
self,
input_ids=None,
inputs_embeds=None,
past_key_values=None,
attention_mask=None,
**kwargs,
):
model_inputs = {}
            # NOTE: see src/transformers/models/deprecated/open_llama/modeling_open_llama.py:prepare_inputs_for_generation
            # If `inputs_embeds` is passed, we only use it in the first generation step.
            # The KVs are then cached (via `past_key_values`), and subsequent steps
            # only need `input_ids`.
if inputs_embeds is not None and past_key_values is None:
model_inputs["inputs_embeds"] = inputs_embeds
            if past_key_values is None or not isinstance(past_key_values, InferenceParams):
past_key_values = InferenceParams(
max_seqlen=self.config.n_positions,
max_batch_size=input_ids.shape[0],
seqlen_offset=0,
batch_size_offset=0,
key_value_memory_dict={},
lengths_per_sample=None,
)
else:
# Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
# NOTE: if use inputs_embeds, we use attention_mask to get the seqlen_offset
past_key_values.seqlen_offset = attention_mask.shape[1] - 1
input_ids = input_ids[:, -1].unsqueeze(-1)
if "inputs_embeds" not in model_inputs:
model_inputs["input_ids"] = input_ids
model_inputs.update(
{
"past_key_values": past_key_values,
"attention_mask": attention_mask,
}
)
return model_inputs
        # NOTE: replace the method on the fly.
        # https://stackoverflow.com/questions/52292599/can-i-replace-an-existing-method-of-an-object-in-python
model.language_model.prepare_inputs_for_generation = types.MethodType(
phi_prepare_inputs_for_generation, model.language_model
)
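        # Sketch of the resulting generation flow (assumed; mirrors open_llama's cache handling):
        #   step 0:   prepare_inputs_for_generation(input_ids, inputs_embeds=E, past_key_values=None)
        #             -> {"inputs_embeds": E, "past_key_values": InferenceParams(...), "attention_mask": ...}
        #   step t>0: past_key_values is already an InferenceParams, so only the last token is kept
        #             -> {"input_ids": input_ids[:, -1:], ...} with seqlen_offset = attention_mask.shape[1] - 1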
logger.info(f"Model: {model.config}")
return model
def prepare_datasets(args):
train_data = []
for train_data_config_name in args.train_data:
cfg = hydra.compose(config_name=f"data/{train_data_config_name}", overrides=args.train_data_overrides)
train_data.append(cfg.data)
args.train_data = train_data
    # NOTE(xiaoke): We should only run inference on one eval dataset.
    if len(args.eval_data) > 1:
        logger.warning(f"We should only run inference on one eval dataset, got {args.eval_data}")
eval_data = []
for eval_data_config_name in args.eval_data:
cfg = hydra.compose(config_name=f"data/{eval_data_config_name}", overrides=args.eval_data_overrides)
eval_data.append(cfg.data)
train_dataset = []
for i, each_train_data in enumerate(train_data):
# NOTE: add data `split` to each dataset
each_train_data.split = "train"
_train_dataset = instantiate(each_train_data)
train_dataset.append(_train_dataset)
logger.info(f"Train Dataset [{i}]: {each_train_data}\n{_train_dataset}")
eval_dataset = {}
for i, each_eval_data in enumerate(eval_data):
# NOTE: add data `split` to each dataset
# NOTE: visual genome has validation set, but we use test set for evaluation
if "visual_genome.py" in each_eval_data.path and getattr(each_eval_data, "use_densecap_splits", None) is True:
logger.info("Using densecap splits in Visual Genome, using test split to eval")
each_eval_data.split = "test"
# NOTE: refcoco has validation set, but we use test set for evaluation
elif "refcoco.py" in each_eval_data.path:
if each_eval_data.name.startswith("refcoco-") or each_eval_data.name.startswith("refcoco+-"):
if each_eval_data.split is None or each_eval_data.split == "train":
raise ValueError(f"refcoco{{,+}} must have split for eval. got {each_eval_data.split}")
logger.info(f"Using refcoco{{,+}}: {each_eval_data.split} split to eval")
elif each_eval_data.name.startswith("refcocog"):
logger.info("Using refcocog val split to eval")
each_eval_data.split = "validation"
elif each_eval_data.name.startswith("refclef"):
logger.info("Using refclef val split to eval")
each_eval_data.split = "validation"
# NOTE: coco has validation set, but it does not have test set.
elif "coco_instance.py" in each_eval_data.path or "coco_instance-local.py" in each_eval_data.path:
logger.info("Using coco val split to eval")
each_eval_data.split = "validation"
elif "objects365-local.py" in each_eval_data.path:
logger.info("Using objects365 (in fact, it is COCO) val split to eval")
each_eval_data.split = "validation"
elif "v3det-local.py" in each_eval_data.path:
logger.info("Using v3det (in fact, it is COCO) val split to eval")
each_eval_data.split = "validation"
elif "sbu-pseudo_region-local.py" in each_eval_data.path or "sbu-pseudo_region.py" in each_eval_data.path:
logger.info("Using sbu to eval, but it does not have test split, so we use train split")
each_eval_data.split = "train"
elif "coco_caption-pseudo_region.py" in each_eval_data.path:
logger.info("Using coco_caption (in fact, it is COCO) val split to eval")
each_eval_data.split = "validation"
elif (
"visual_genome-densecap-local.py" in each_eval_data.path
or "visual_genome-grit-local.py" in each_eval_data.path
):
logger.info(f"Using visual_genome (They are my custom splits for GRiT and Densecap) test split to eval")
each_eval_data.split = "test"
elif "m3d_2d.py" in each_eval_data.path:
logger.info(f"Using m3d_2d val split to eval")
each_eval_data.split = "validation"
else:
raise ValueError(
f"Unknown dataset {each_eval_data.path}, we cannot determine the split for it. Please edit `src/train.py:prepare_datasets` to add the split for it."
)
print(f"each_eval_data{each_eval_data}")
_eval_dataset = instantiate(each_eval_data)
eval_dataset_name = _get_data_name(each_eval_data)
eval_dataset[eval_dataset_name] = _eval_dataset
logger.info(f"Eval Dataset [{i}]: {each_eval_data}\n{_eval_dataset}")
args.eval_data = eval_data # NOTE: overwrite previous eval_data
if args.train_data_interleave_probabilities is not None and len(train_dataset) != len(
args.train_data_interleave_probabilities
):
raise ValueError(
f"train_data_interleave_probabilities must have the same length as train_data, got {len(train_dataset)} and {len(args.train_data_interleave_probabilities)}"
)
# NOTE(xiaoke): Expected a list of Dataset objects or a list of IterableDataset objects.
if len(train_dataset) > 0:
if args.train_data_interleave_probabilities is None:
            logger.warning(
                "train_data_interleave_probabilities is not provided; "
                "the resulting dataset will have max_dataset_length * num_datasets samples, "
                "since the `all_exhausted` stopping strategy oversamples the shorter datasets."
            )
else:
if sum(args.train_data_interleave_probabilities) != 1.0:
logger.info(f"Normalize train_data_interleave_probabilities to sum to 1.0")
args.train_data_interleave_probabilities = [
each_prob / sum(args.train_data_interleave_probabilities)
for each_prob in args.train_data_interleave_probabilities
]
logger.info(f"train_data_interleave_probabilities: {args.train_data_interleave_probabilities}")
        # NOTE(xiaoke): According to `datasets/src/datasets/arrow_dataset.py:_interleave_map_style_datasets:6079` and
        # `datasets/src/datasets/iterable_dataset.py:_interleave_iterable_datasets:2293`
train_dataset = interleave_datasets(
train_dataset,
probabilities=args.train_data_interleave_probabilities,
seed=args.training.seed,
stopping_strategy="all_exhausted",
)
else:
train_dataset = None
logger.info(f"Train Dataset: {train_dataset}")
logger.info(f"Eval Dataset: {eval_dataset}")
return train_dataset, eval_dataset
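# Sketch of the Hydra config shape `prepare_datasets` consumes (hypothetical dataset names):
#   train_data: [vg_densecap, sa1b_cap]
#   eval_data: [refcoco]
#   train_data_interleave_probabilities: [0.7, 0.3]
# Each name resolves to a `data/<name>` config group entry whose `data` node is instantiated.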
def _get_data_name(dataset_config_dict):
# NOTE: path is the path for data script
path = dataset_config_dict.path
path_name = os.path.splitext(os.path.basename(path))[0]
name = dataset_config_dict.name
split = dataset_config_dict.split
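    # e.g. (hypothetical values) path=".../refcoco.py", name="refcocog", split="validation"
    # -> "refcoco-refcocog-validation"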
return f"{path_name}-{name}-{split}"
def prepare_processor(model_args, use_auth_token):
if isinstance(model_args, SAMCaptionerModelArguments):
processor = SAMCaptionerProcessor.from_sam_captioner_pretrained(
model_args.sam_model_name_or_path,
model_args.captioner_model_name_or_path,
cache_dir=model_args.cache_dir,
model_max_length=model_args.model_max_length,
use_auth_token=use_auth_token,
trust_remote_code=True,
)
    # NOTE: when loading weights from an existing SCA model, we should use the same tokenizer as that model.
    # Use `python scripts/tools/get_sub_model_name_from_ckpt.py $BEST_CKPT_PATH $MODEL_TYPE` to get the model_type.
elif isinstance(model_args, SCAModelBaseArguments):
processor = ScaProcessor.from_sam_text_pretrained(
model_args.sam_model_name_or_path,
model_args.lm_head_model_name_or_path,
cache_dir=model_args.cache_dir,
model_max_length=model_args.model_max_length,
use_auth_token=use_auth_token,
trust_remote_code=True,
)
else:
raise ValueError(
f"model_args must be one of [SAMCaptionerModelArguments, SCAModelBaseArguments], got {type(model_args)}"
)
# NOTE(xiaoke): add pad_token if not exists
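    # e.g. GPT-2-style tokenizers define an eos_token but no pad_token,
    # so padding falls back to the eos token.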
    if processor.tokenizer.pad_token is None:
        if processor.tokenizer.eos_token is None:
            raise ValueError("tokenizer must have an eos_token to use as the pad_token")
        processor.tokenizer.pad_token = processor.tokenizer.eos_token
return processor
if __name__ == "__main__":
main()