import json
import logging
import os
import sys
from pathlib import Path

import finetune_rag

from transformers.file_utils import is_apex_available
from transformers.testing_utils import (
    TestCasePlus,
    execute_subprocess_async,
    require_ray,
    require_torch_gpu,
    require_torch_multi_gpu,
)


logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()

stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class RagFinetuneExampleTests(TestCasePlus):
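    """Lightweight end-to-end tests for the RAG fine-tuning example script (finetune_rag.py)."""
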
    def _create_dummy_data(self, data_dir):
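        """Write tiny parallel source/target files for the train, val, and test splits."""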
        os.makedirs(data_dir, exist_ok=True)
        contents = {"source": "What is love ?", "target": "life"}
        n_lines = {"train": 12, "val": 2, "test": 2}
        for split in ["train", "test", "val"]:
            for field in ["source", "target"]:
                content = "\n".join([contents[field]] * n_lines[split])
                with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f:
                    f.write(content)

    def _run_finetune(self, gpus: int, distributed_retriever: str = "pytorch"):
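        """Fine-tune facebook/rag-sequence-base on the dummy data via finetune_rag.py and return the metrics it saves."""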
        tmp_dir = self.get_auto_remove_tmp_dir()
        output_dir = os.path.join(tmp_dir, "output")
        data_dir = os.path.join(tmp_dir, "data")
        self._create_dummy_data(data_dir=data_dir)
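
        # Small batch sizes, short sequences, a single epoch, and --use_dummy_dataset keep this smoke test quick.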
        testargs = f"""
                --data_dir {data_dir} \
                --output_dir {output_dir} \
                --model_name_or_path facebook/rag-sequence-base \
                --model_type rag_sequence \
                --do_train \
                --do_predict \
                --n_val -1 \
                --val_check_interval 1.0 \
                --train_batch_size 2 \
                --eval_batch_size 1 \
                --max_source_length 25 \
                --max_target_length 25 \
                --val_max_target_length 25 \
                --test_max_target_length 25 \
                --label_smoothing 0.1 \
                --dropout 0.1 \
                --attention_dropout 0.1 \
                --weight_decay 0.001 \
                --adam_epsilon 1e-08 \
                --max_grad_norm 0.1 \
                --lr_scheduler polynomial \
                --learning_rate 3e-04 \
                --num_train_epochs 1 \
                --warmup_steps 4 \
                --gradient_accumulation_steps 1 \
                --distributed-port 8787 \
                --use_dummy_dataset 1 \
                --distributed_retriever {distributed_retriever} \
            """.split()
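
        # Use GPUs when requested (adding --fp16 when apex is available); otherwise run CPU-only DDP with two processes.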
        if gpus > 0:
            testargs.append(f"--gpus={gpus}")
            if is_apex_available():
                testargs.append("--fp16")
        else:
            testargs.append("--gpus=0")
            testargs.append("--distributed_backend=ddp_cpu")
            testargs.append("--num_processes=2")
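
        # Launch finetune_rag.py as a separate process with the assembled command-line arguments.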
        cmd = [sys.executable, str(Path(finetune_rag.__file__).resolve())] + testargs
        execute_subprocess_async(cmd, env=self.get_env())
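
        # Read back the metrics that finetune_rag.py saved to the output directory.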
        metrics_save_path = os.path.join(output_dir, "metrics.json")
        with open(metrics_save_path) as f:
            result = json.load(f)
        return result
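
    # Each test asserts that the run reaches an average exact match (test_avg_em) of at least 0.2 on the dummy test split.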
    @require_torch_gpu
    def test_finetune_gpu(self):
        result = self._run_finetune(gpus=1)
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)

    @require_torch_multi_gpu
    def test_finetune_multigpu(self):
        result = self._run_finetune(gpus=2)
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)

    @require_torch_gpu
    @require_ray
    def test_finetune_gpu_ray_retrieval(self):
        result = self._run_finetune(gpus=1, distributed_retriever="ray")
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)

    @require_torch_multi_gpu
    @require_ray
    def test_finetune_multigpu_ray_retrieval(self):
        result = self._run_finetune(gpus=2, distributed_retriever="ray")
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)