# HuggingFace trainer
import json
import os
from dataclasses import dataclass
from typing import Callable, Dict, Optional

import torch
from datasets import concatenate_datasets
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, Idefics2Processor, PreTrainedModel, PreTrainedTokenizer, TrainingArguments

from colpali_engine.dataset.custom_collator import CustomCollator
from colpali_engine.loss.colbert_loss import BiEncoderLoss, BiPairwiseCELoss, ColbertLoss, ColbertPairwiseCELoss
from colpali_engine.trainer.contrastive_trainer import ContrastiveTrainer
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.gpu_stats import print_gpu_utilization, print_summary
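

# Configuration dataclass gathering everything a training run needs: the model,
# HF TrainingArguments, an optional PEFT/LoRA config, the processor or tokenizer
# used by the collator, the contrastive loss, and the dataset-loading callables.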
@dataclass
class ColModelTrainingConfig:
    model: PreTrainedModel
    tr_args: TrainingArguments = None
    output_dir: str = None
    max_length: int = 256
    run_eval: bool = True
    run_train: bool = True
    peft_config: Optional[LoraConfig] = None
    add_suffix: bool = False
    processor: Idefics2Processor = None
    tokenizer: PreTrainedTokenizer = None
    loss_func: Optional[Callable] = ColbertLoss()
    dataset_loading_func: Optional[Callable] = None
    eval_dataset_loader: Optional[Dict[str, Callable]] = None
    pretrained_peft_model_name_or_path: Optional[str] = None

    def __post_init__(self):
        if self.output_dir is None:
            sanitized_name = str(self.model.name_or_path).replace("/", "_")
            self.output_dir = f"./models/{sanitized_name}"

        if self.tr_args is None:
            self.tr_args = TrainingArguments(output_dir=self.output_dir)
        elif self.tr_args.output_dir is None:
            self.tr_args.output_dir = self.output_dir

        # cast the learning rate to float if it was parsed from the config as a string
        if isinstance(self.tr_args.learning_rate, str):
            self.tr_args.learning_rate = float(self.tr_args.learning_rate)
        self.tr_args.remove_unused_columns = False

        if self.processor is None and self.tokenizer is None:
            print("Using textual model tokenization")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model.name_or_path)

        if self.pretrained_peft_model_name_or_path is not None:
            self.model.load_adapter(self.pretrained_peft_model_name_or_path)
            print(f"Loaded pretrained adapter from {self.pretrained_peft_model_name_or_path}")
        if self.peft_config is not None:
            print("Configuring PEFT model")
            if self.processor is None:
                # Might be deprecated - use the "else" branch
                self.model = prepare_model_for_kbit_training(self.model)  # use_gradient_checkpointing=True
                # self.model.enable_input_require_grads()
                self.model = get_peft_model(self.model, self.peft_config)
                self.model.print_trainable_parameters()
            else:
                # Ugly debugging hack
                # if self.model.model.config.text_config.vocab_size == 32000:
                #     print("DEBUG: Resizing token embeddings - This should not happen in a real scenario!")
                #     self.model.model.text_model.resize_token_embeddings(32003)
                #     self.model.model.vision_model.encoder.layers = self.model.model.vision_model.encoder.layers[0:2]
                # self.model.enable_input_require_grads()
                if self.pretrained_peft_model_name_or_path is None:
                    self.model.add_adapter(self.peft_config)
                    self.model.enable_adapters()
                else:
                    print(f"Adapter already loaded from {self.pretrained_peft_model_name_or_path}. Not overwriting.")

        print_gpu_utilization()
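

# Orchestrates a full run from a ColModelTrainingConfig: builds the collator and the
# retrieval evaluator, runs contrastive training, evaluates retrieval metrics on the
# validation set plus any extra eval datasets, and saves the model, config, and git hash.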
class ColModelTraining:
    def __init__(self, config: ColModelTrainingConfig) -> None:
        self.config = config
        self.model = self.config.model
        self.dataset = self.config.dataset_loading_func()
        self.collator = CustomCollator(
            processor=self.config.processor, tokenizer=self.config.tokenizer, max_length=self.config.max_length
        )
        self.current_git_hash = os.popen("git rev-parse HEAD").read().strip()
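        # ColBERT-style losses score a bag of token embeddings per query/document
        # (late interaction), so the evaluator must run in multi-vector mode.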
        self.retriever_evaluator = CustomEvaluator(
            is_multi_vector=(
                isinstance(self.config.loss_func, ColbertLoss)
                or isinstance(self.config.loss_func, ColbertPairwiseCELoss)
            )
        )

    def train(self) -> None:
        trainer = ContrastiveTrainer(
            model=self.model,
            train_dataset=self.dataset["train"],
            eval_dataset=self.dataset["test"],
            args=self.config.tr_args,
            data_collator=self.collator,
            loss_func=self.config.loss_func,
            is_vision_model=self.config.processor is not None,
        )
        trainer.args.remove_unused_columns = False

        result = trainer.train()
        print_summary(result)

    def eval_dataset(self, test_dataset):
        self.model.eval()

        # # debug
        # if len(test_dataset) > 200:
        #     test_dataset = test_dataset.select(range(0, 100))
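
        # Some eval datasets contain extra documents with no associated query; they are
        # batched separately and only enlarge the retrieval corpus.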
        idx_with_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is not None]
        idx_without_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is None]

        dataloader_with_query = DataLoader(
            test_dataset.select(idx_with_query),
            batch_size=self.config.tr_args.per_device_eval_batch_size,
            shuffle=False,
            collate_fn=self.collator,
        )
        dataloader_without_query = DataLoader(
            test_dataset.select(idx_without_query),
            batch_size=self.config.tr_args.per_device_eval_batch_size,
            shuffle=False,
            collate_fn=self.collator,
        )

        # dataset is ordered so that non-null queries come first
        test_dataset = concatenate_datasets(
            [test_dataset.select(idx_with_query), test_dataset.select(idx_without_query)]
        )

        relevant_docs = {}
        docidx_2_docid = {}
        qsidx_2_query = []
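        # Build qrels-style structures: each query index maps to its single relevant
        # document id, and every row index maps to a document id for the score matrix.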
        for idx, sample in enumerate(test_dataset):
            doc_id = sample["image_filename"] if "image_filename" in sample else str(hash(sample["doc"]))
            # query_id = sample["query_id"] if "query_id" in sample else str(hash(sample["query"]))
            if sample["query"] is not None:
                relevant_docs[str(idx)] = {doc_id: 1}
                qsidx_2_query.append(str(idx))
            docidx_2_docid[str(idx)] = doc_id

        qs = []
        ps = []
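
        # qs collects query embeddings, ps collects document/passage embeddings;
        # both are moved to CPU to keep GPU memory flat during evaluation.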
        device = self.model.device
        with torch.no_grad():
            for dataloader in [dataloader_with_query, dataloader_without_query]:
                for batch in tqdm(dataloader):
                    if "doc_pixel_values" not in batch:
                        doc = self.model(
                            input_ids=batch["doc_input_ids"].to(device),
                            attention_mask=batch["doc_attention_mask"].to(device),
                        )
                    else:
                        if "doc_pixel_attention_mask" in batch:
                            doc = self.model(
                                input_ids=batch["doc_input_ids"].to(device),
                                attention_mask=batch["doc_attention_mask"].to(device),
                                pixel_values=batch["doc_pixel_values"].to(device),
                                pixel_attention_mask=batch["doc_pixel_attention_mask"].to(device),
                            )
                        else:
                            doc = self.model(
                                input_ids=batch["doc_input_ids"].to(device),
                                attention_mask=batch["doc_attention_mask"].to(device),
                                pixel_values=batch["doc_pixel_values"].to(device),
                            )

                    ps.extend(list(torch.unbind(doc.to("cpu"))))

                    if "query_input_ids" in batch:
                        query = self.model(
                            input_ids=batch["query_input_ids"].to(device),
                            attention_mask=batch["query_attention_mask"].to(device),
                        )
                        # variable len
                        qs.extend(list(torch.unbind(query.to("cpu"))))
| print("Embeddings computed, evaluating") | |
| scores = self.retriever_evaluator.evaluate(qs, ps) | |
| # scores is 2d array of shape (n_queries, n_docs) | |
| # turn it into a dict | |
| results = {} | |
| assert scores.shape[0] == len(qsidx_2_query) | |
| for idx, scores_per_query in enumerate(scores): | |
| results[qsidx_2_query[idx]] = { | |
| docidx_2_docid[str(docidx)]: float(score) for docidx, score in enumerate(scores_per_query) | |
| } | |
| # evaluate | |
| metrics = self.retriever_evaluator.compute_metrics(relevant_docs, results) | |
| print(metrics) | |
| return metrics | |

    def eval(self) -> None:
        print("Evaluating on validation set")
        metrics = self.eval_dataset(self.dataset["test"])
        print(f"Metrics for validation set: {metrics}")
        all_metrics = {"validation_set": metrics}

        if self.config.eval_dataset_loader is not None:
            for test_name, test_dataset_loading_func in self.config.eval_dataset_loader.items():
                print(f"Evaluating {test_name}")
                test_ds = test_dataset_loading_func()
                metrics = self.eval_dataset(test_ds)
                all_metrics[test_name] = metrics
                print(f"Metrics for {test_name}: {metrics}")

                # checkpoint dump after each eval dataset so partial results survive a crash
                with open(f"{self.config.output_dir}/results.json", "w") as f:
                    json.dump(all_metrics, f)

        # save the final results as json
        with open(f"{self.config.output_dir}/results.json", "w") as f:
            json.dump(all_metrics, f)

    def save(self, config_file):
        # save model
        self.model.save_pretrained(self.config.output_dir)
        if self.config.tokenizer is not None:
            self.config.tokenizer.save_pretrained(self.config.output_dir)
        if self.config.processor is not None:
            self.config.processor.save_pretrained(self.config.output_dir)  # save config

        # copy the training yml config next to the saved weights
        os.system(f"cp {config_file} {self.config.output_dir}/training_config.yml")

        # save git hash of the commit at beginning of training
        with open(f"{self.config.output_dir}/git_hash.txt", "w") as f:
            f.write(self.current_git_hash)
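

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): one plausible
# way to wire the config and trainer together. The checkpoint name, the
# `load_my_train_set` loader (assumed to return a dict with "train" and "test"
# splits), the LoRA hyperparameters, and the config path are placeholder
# assumptions, not values taken from this repository.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModel

    def load_my_train_set():
        # Hypothetical loader: replace with a real function returning
        # {"train": Dataset, "test": Dataset} with "query"/"doc"-style columns.
        raise NotImplementedError

    base_model = AutoModel.from_pretrained("my-org/my-retrieval-backbone")  # placeholder checkpoint
    training_config = ColModelTrainingConfig(
        model=base_model,
        dataset_loading_func=load_my_train_set,
        peft_config=LoraConfig(r=32, lora_alpha=32, lora_dropout=0.1),  # illustrative values
    )

    app = ColModelTraining(training_config)
    if training_config.run_train:
        app.train()
    if training_config.run_eval:
        app.eval()
    app.save(config_file="configs/training_config.yml")  # placeholder path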