import json
import logging
import os
import sys
from pathlib import Path

import finetune_rag

from transformers.file_utils import is_apex_available
from transformers.testing_utils import (
    TestCasePlus,
    execute_subprocess_async,
    require_ray,
    require_torch_gpu,
    require_torch_multi_gpu,
)


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class RagFinetuneExampleTests(TestCasePlus):
    def _create_dummy_data(self, data_dir):
        os.makedirs(data_dir, exist_ok=True)
        contents = {"source": "What is love ?", "target": "life"}
        n_lines = {"train": 12, "val": 2, "test": 2}
        # Write one source/target file pair per split, repeating the same dummy line.
        for split in ["train", "test", "val"]:
            for field in ["source", "target"]:
                content = "\n".join([contents[field]] * n_lines[split])
                with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f:
                    f.write(content)

    def _run_finetune(self, gpus: int, distributed_retriever: str = "pytorch"):
        tmp_dir = self.get_auto_remove_tmp_dir()
        output_dir = os.path.join(tmp_dir, "output")
        data_dir = os.path.join(tmp_dir, "data")
        self._create_dummy_data(data_dir=data_dir)

        # Deliberately tiny hyperparameters: this is a smoke test, not a real run.
        testargs = f"""
                --data_dir {data_dir} \
                --output_dir {output_dir} \
                --model_name_or_path facebook/rag-sequence-base \
                --model_type rag_sequence \
                --do_train \
                --do_predict \
                --n_val -1 \
                --val_check_interval 1.0 \
                --train_batch_size 2 \
                --eval_batch_size 1 \
                --max_source_length 25 \
                --max_target_length 25 \
                --val_max_target_length 25 \
                --test_max_target_length 25 \
                --label_smoothing 0.1 \
                --dropout 0.1 \
                --attention_dropout 0.1 \
                --weight_decay 0.001 \
                --adam_epsilon 1e-08 \
                --max_grad_norm 0.1 \
                --lr_scheduler polynomial \
                --learning_rate 3e-04 \
                --num_train_epochs 1 \
                --warmup_steps 4 \
                --gradient_accumulation_steps 1 \
                --distributed-port 8787 \
                --use_dummy_dataset 1 \
                --distributed_retriever {distributed_retriever} \
            """.split()

        if gpus > 0:
            testargs.append(f"--gpus={gpus}")
            if is_apex_available():
                testargs.append("--fp16")
        else:
            # No GPU available: fall back to 2-process distributed training on CPU.
            testargs.append("--gpus=0")
            testargs.append("--distributed_backend=ddp_cpu")
            testargs.append("--num_processes=2")

        cmd = [sys.executable, str(Path(finetune_rag.__file__).resolve())] + testargs
        execute_subprocess_async(cmd, env=self.get_env())

        # The finetuning script dumps its metrics to a JSON file in the output dir.
        metrics_save_path = os.path.join(output_dir, "metrics.json")
        with open(metrics_save_path) as f:
            result = json.load(f)
        return result

    @require_torch_gpu
    def test_finetune_gpu(self):
        result = self._run_finetune(gpus=1)
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)

    @require_torch_multi_gpu
    def test_finetune_multigpu(self):
        result = self._run_finetune(gpus=2)
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)

    @require_torch_gpu
    @require_ray
    def test_finetune_gpu_ray_retrieval(self):
        result = self._run_finetune(gpus=1, distributed_retriever="ray")
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)

    @require_torch_multi_gpu
    @require_ray
    def test_finetune_multigpu_ray_retrieval(self):
        result = self._run_finetune(gpus=1, distributed_retriever="ray")
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)