File size: 3,954 Bytes
a0db2f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import json
import logging
import os
import sys
from pathlib import Path
import finetune_rag
from transformers.file_utils import is_apex_available
from transformers.testing_utils import (
TestCasePlus,
execute_subprocess_async,
require_ray,
require_torch_gpu,
require_torch_multi_gpu,
)
# Root-logger setup for the test run: request DEBUG level via basicConfig, then
# attach an extra handler that writes records to stdout.
# NOTE: basicConfig has to run before the extra handler is attached; per the
# logging docs it does nothing once the root logger already has handlers.
logging.basicConfig(level=logging.DEBUG)
stream_handler = logging.StreamHandler(stream=sys.stdout)
logger = logging.getLogger()
logger.addHandler(stream_handler)
class RagFinetuneExampleTests(TestCasePlus):
    """End-to-end tests that launch ``finetune_rag`` in a subprocess on tiny dummy data.

    Every test runs the script once, loads ``metrics.json`` from the run's
    output directory and checks the reported ``test_avg_em`` value.
    """

    # Lowest acceptable ``test_avg_em`` value read back from ``metrics.json``.
    _MIN_TEST_AVG_EM = 0.2

    def _create_dummy_data(self, data_dir):
        """Write ``{split}.source`` / ``{split}.target`` files into ``data_dir``.

        The directory is created if missing. Each file repeats one fixed line:
        12 times for ``train``, twice for ``val`` and twice for ``test``.

        Args:
            data_dir: Directory that receives the six dummy files.
        """
        os.makedirs(data_dir, exist_ok=True)
        contents = {"source": "What is love ?", "target": "life"}
        n_lines = {"train": 12, "val": 2, "test": 2}
        for split in ["train", "test", "val"]:
            for field in ["source", "target"]:
                content = "\n".join([contents[field]] * n_lines[split])
                path = os.path.join(data_dir, f"{split}.{field}")
                with open(path, "w", encoding="utf-8") as f:
                    f.write(content)

    def _run_finetune(self, gpus: int, distributed_retriever: str = "pytorch"):
        """Run ``finetune_rag`` in a subprocess and return its parsed metrics.

        Args:
            gpus: Number of GPUs to request. A value of 0 or less selects the
                CPU branch (``--gpus=0`` with ``ddp_cpu`` and two processes).
            distributed_retriever: Value forwarded as ``--distributed_retriever``.

        Returns:
            The object decoded from ``<output_dir>/metrics.json``.
        """
        tmp_dir = self.get_auto_remove_tmp_dir()
        output_dir = os.path.join(tmp_dir, "output")
        data_dir = os.path.join(tmp_dir, "data")
        self._create_dummy_data(data_dir=data_dir)

        # Built as an explicit list rather than by splitting one big string, so
        # that a temp path containing whitespace stays a single argument.
        # fmt: off
        testargs = [
            "--data_dir", data_dir,
            "--output_dir", output_dir,
            "--model_name_or_path", "facebook/rag-sequence-base",
            "--model_type", "rag_sequence",
            "--do_train",
            "--do_predict",
            "--n_val", "-1",
            "--val_check_interval", "1.0",
            "--train_batch_size", "2",
            "--eval_batch_size", "1",
            "--max_source_length", "25",
            "--max_target_length", "25",
            "--val_max_target_length", "25",
            "--test_max_target_length", "25",
            "--label_smoothing", "0.1",
            "--dropout", "0.1",
            "--attention_dropout", "0.1",
            "--weight_decay", "0.001",
            "--adam_epsilon", "1e-08",
            "--max_grad_norm", "0.1",
            "--lr_scheduler", "polynomial",
            "--learning_rate", "3e-04",
            "--num_train_epochs", "1",
            "--warmup_steps", "4",
            "--gradient_accumulation_steps", "1",
            "--distributed-port", "8787",
            "--use_dummy_dataset", "1",
            "--distributed_retriever", distributed_retriever,
        ]
        # fmt: on

        if gpus > 0:
            testargs.append(f"--gpus={gpus}")
            # Mixed precision is only requested when apex is installed.
            if is_apex_available():
                testargs.append("--fp16")
        else:
            testargs.append("--gpus=0")
            testargs.append("--distributed_backend=ddp_cpu")
            testargs.append("--num_processes=2")

        # Launch the script with the same interpreter that runs the tests.
        cmd = [sys.executable, str(Path(finetune_rag.__file__).resolve())] + testargs
        execute_subprocess_async(cmd, env=self.get_env())

        metrics_save_path = os.path.join(output_dir, "metrics.json")
        with open(metrics_save_path, encoding="utf-8") as f:
            result = json.load(f)
        return result

    @require_torch_gpu
    def test_finetune_gpu(self):
        """Single-GPU run with the default (``pytorch``) retriever."""
        result = self._run_finetune(gpus=1)
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], self._MIN_TEST_AVG_EM)

    @require_torch_multi_gpu
    def test_finetune_multigpu(self):
        """Two-GPU run with the default (``pytorch``) retriever."""
        result = self._run_finetune(gpus=2)
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], self._MIN_TEST_AVG_EM)

    @require_torch_gpu
    @require_ray
    def test_finetune_gpu_ray_retrieval(self):
        """Single-GPU run with the ``ray`` retriever."""
        result = self._run_finetune(gpus=1, distributed_retriever="ray")
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], self._MIN_TEST_AVG_EM)

    @require_torch_multi_gpu
    @require_ray
    def test_finetune_multigpu_ray_retrieval(self):
        """Two-GPU run with the ``ray`` retriever."""
        # Previously requested a single GPU, which made this test a duplicate
        # of test_finetune_gpu_ray_retrieval despite the multi-GPU requirement.
        result = self._run_finetune(gpus=2, distributed_retriever="ray")
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], self._MIN_TEST_AVG_EM)
|