|
--- |
|
datasets: |
|
- gaussalgo/Canard_Wiki-augmented |
|
- hotpot_qa |
|
metrics: |
|
- rouge |
|
- bleu |
|
model-index: |
|
- name: T5-LM-Large_Canard-Fullwiki-HotpotQA-rephrase |
|
results: |
|
- task: |
|
type: question-answering |
|
name: Question Answering |
|
dataset: |
|
type: hotpot_qa |
|
name: HotpotQA |
|
split: validation |
|
metrics: |
|
- type: rouge |
|
value: 0.4774 |
|
- type: bleu |
|
value: 29.11 |
|
- task: |
|
type: question-answering |
|
name: Question Answering |
|
dataset: |
|
type: gaussalgo/Canard_Wiki-augmented |
|
name: Wikipedia-augmented Conversational QA (Canard) |
|
split: validation |
|
metrics: |
|
- type: rouge |
|
value: 0.4377 |
|
- type: bleu |
|
value: 19.34 |
|
license: cc-by-sa-4.0 |
|
language: |
|
- en |
|
--- |
|
|
|
# Model Card for T5-LM-Large_Canard-Fullwiki-HotpotQA-rephrase
|
This model is trained on three objectives: |
|
1. Generating answers for Canard dataset based on Wikipedia search results |
|
2. Generating answers for HotpotQA, |
|
3. Rephrasing questions by the conversation context. |
|
|
|
## Training |
|
The model was trained using the following script, which can be copy-pasted and run as-is (with the installed `requirements.txt`). |
|
All details, including the expected prompt format, can be inferred directly from the code below.
|
The best checkpoint was picked by the maximum ROUGE on the Canard conversational QA validation set.
|
|
|
```python |
|
import datasets

# Load the Wikipedia-augmented Canard splits (conversational QA).
canard_train_augm = datasets.load_dataset("gaussalgo/Canard_Wiki-augmented", split="train")
canard_test_augm = datasets.load_dataset("gaussalgo/Canard_Wiki-augmented", split="test")

canard_df = canard_train_augm.to_pandas()
# BUG FIX: this was previously built from `canard_train_augm`, so the
# "test" dataframe silently duplicated the training data and validation
# would have been performed on training samples.
canard_test_df = canard_test_augm.to_pandas()


### Curation of seq2seq input contexts and labels
import random
|
|
|
def input_context_from_sample(row: dict, max_length=5) -> str:
    """Build the conversational-QA prompt for one Canard row.

    The prompt contains the conversation history, the current question and
    an enumerated list of search results; the gold passage is mixed into the
    (at most ``max_length``) retrieved passages at a random position.
    Assumes ``row["retrieved_contexts"]`` is a numpy array — TODO confirm
    against the dataset schema.
    """
    history = row["History"]
    # The first three history entries are rendered as a single comma-joined
    # "Question" line; afterwards entries alternate Answer / Question.
    parts = ["Previous conversation:", "\nQuestion: ", ", ".join(history[:3])]
    for idx in range(3, len(history), 2):
        parts.append("\nAnswer: ")
        parts.append(history[idx])
        if idx + 1 < len(history):
            parts.append("\nQuestion: ")
            parts.append(history[idx + 1])

    parts.append("\n\nCurrent Question: ")
    parts.append(row["Question"])

    parts.append("\nSearch results:")
    # Keep max_length-1 retrieved passages, add the true passage, and shuffle
    # so the model cannot learn a positional shortcut to the gold context.
    pool = row["retrieved_contexts"].tolist()[:max_length - 1] + [row["true_contexts"]]
    random.shuffle(pool)

    for pos, passage in enumerate(pool, start=1):
        parts.append("\n[%s]: " % pos)
        # Strip the dataset's explicit no-answer marker from the passages.
        parts.append(passage.replace("CANNOTANSWER", ""))

    parts.append("\nCurrent Answer: ")
    return "".join(parts)
|
|
|
def rephrasing_context_from_sample(row: dict) -> str:
    """Build the question-rephrasing prompt for one Canard row.

    Renders the conversation history and the current question, ending with
    a "More specific question: " cue for the model to complete.
    """
    turns = row["History"]
    # First three turns become one comma-joined "Question" line; the rest
    # alternate Answer / Question.
    pieces = ["Previous conversation:", "\nQuestion: ", ", ".join(turns[:3])]
    for pos in range(3, len(turns), 2):
        pieces.append("\nAnswer: ")
        pieces.append(turns[pos])
        if pos + 1 < len(turns):
            pieces.append("\nQuestion: ")
            pieces.append(turns[pos + 1])

    pieces.append("\n\nCurrent Question: ")
    pieces.append(row["Question"])

    pieces.append("\nMore specific question: ")
    return "".join(pieces)
|
|
|
def hotpotqa_context(row: dict) -> str:
    """Build the HotpotQA prompt: the question plus enumerated passages.

    Each passage is the space-joined sentences of one context document from
    ``row["context"]["sentences"]``.
    """
    pieces = ["Current Question: ", row["question"], "\nSearch results:"]

    passages = [" ".join(sentences) for sentences in row["context"]["sentences"]]
    for pos, passage in enumerate(passages, start=1):
        pieces.append("\n[%s]: " % pos)
        # Strip the explicit no-answer marker, mirroring the Canard prompts.
        pieces.append(passage.replace("CANNOTANSWER", ""))

    pieces.append("\nCurrent Answer: ")
    return "".join(pieces)
|
|
|
# Conversational QA sequences
input_texts = canard_df.apply(lambda row: input_context_from_sample(row), axis=1).values
# Validation is capped at 200 samples to keep evaluation cheap.
input_val_texts = canard_test_df.iloc[:200].apply(lambda row: input_context_from_sample(row), axis=1).values

# Drop prompts longer than 20k characters; they would be heavily truncated.
too_long_index = [len(t) > 20000 for t in input_texts]
input_texts = [t for i, t in enumerate(input_texts) if not too_long_index[i]]
print("training on %s samples" % len(input_texts))

# Map the dataset's explicit CANNOTANSWER marker to a natural "No answer" label.
labels = canard_df.answer.apply(lambda ans: "No answer" if ans == "CANNOTANSWER" else ans).values
# Filter with the same mask so inputs and labels stay aligned.
labels = [l for i, l in enumerate(labels) if not too_long_index[i]]
# BUG FIX: slice to the same first 200 rows as input_val_texts — previously
# the validation labels covered the whole test frame and were misaligned
# in length with the 200 validation inputs.
val_labels = canard_test_df.answer.apply(lambda ans: "No answer" if ans == "CANNOTANSWER" else ans).values[:200]
|
|
|
# Rephrasing sequences: same dialogues, but the target is the rewritten question.
rephrasing_inputs = canard_df.apply(rephrasing_context_from_sample, axis=1).values
rephrasing_val_inputs = canard_test_df.apply(rephrasing_context_from_sample, axis=1).values

rephrasing_labels = canard_df["Rewrite"].values
rephrasing_val_labels = canard_test_df["Rewrite"].values
|
|
|
# HotpotQA sequences (distractor setting).
# FIX: load the dataset once and index both splits — the original called
# load_dataset twice, doubling download/processing work for the same data.
hotpot = datasets.load_dataset("hotpot_qa", "distractor")
hotpot_train = hotpot["train"]
hotpot_val = hotpot["validation"]

hotpot_inputs = hotpot_train.to_pandas().apply(hotpotqa_context, axis=1)
hotpot_val_inputs = hotpot_val.to_pandas().apply(hotpotqa_context, axis=1)
# Same over-length filter as for Canard; applied to the training prompts only.
too_long_index = [len(t) > 20000 for t in hotpot_inputs]

hotpot_inputs = [t for i, t in enumerate(hotpot_inputs) if not too_long_index[i]]
# Filter the answers with the same mask so inputs and labels stay aligned.
hotpot_answers = [t for i, t in enumerate(hotpot_train["answer"]) if not too_long_index[i]]
|
|
|
# Training routine
# see Adaptor's homepage for details:
# https://github.com/gaussalgo/adaptor

# Base model: T5-large with LM adaptation, shared across all three objectives.
from adaptor.lang_module import LangModule

lang_module = LangModule("google/t5-large-lm-adapt")

from adaptor.evaluators.generative import ROUGE, BLEU

# Evaluations: ROUGE drives convergence/early stopping (decides_convergence=True);
# BLEU is computed and logged but does not affect stopping.
evaluators = [BLEU(), ROUGE(decides_convergence=True)]

# Objectives
from adaptor.objectives.seq2seq import Sequence2Sequence
|
|
|
# Objective 1: conversational QA on Canard with retrieved Wikipedia passages.
seq_qa = Sequence2Sequence(lang_module,
                           texts_or_path=input_texts,
                           labels_or_path=labels,
                           val_texts_or_path=input_val_texts,
                           val_labels_or_path=val_labels,
                           batch_size=4,
                           val_evaluators=evaluators,
                           objective_id="Canard")

# Objective 2: HotpotQA multi-hop QA; reuses seq_qa's LM head
# (share_other_objective_head) so all objectives train one set of weights.
seq_additional_qa = Sequence2Sequence(lang_module,
                                      texts_or_path=hotpot_inputs,
                                      labels_or_path=hotpot_answers,
                                      val_texts_or_path=hotpot_val_inputs[:200],
                                      val_labels_or_path=hotpot_val["answer"][:200],
                                      batch_size=4,
                                      val_evaluators=evaluators,
                                      objective_id="HotpotQA",
                                      share_other_objective_head=seq_qa)

# Objective 3: question rephrasing; also shares seq_qa's head.
# Validation is capped at 200 samples, consistent with the other objectives.
seq_rephrasing = Sequence2Sequence(lang_module,
                                   texts_or_path=rephrasing_inputs,
                                   labels_or_path=rephrasing_labels,
                                   val_texts_or_path=rephrasing_val_inputs[:200],
                                   val_labels_or_path=rephrasing_val_labels[:200],
                                   batch_size=4,
                                   val_evaluators=evaluators,
                                   objective_id="rephrasing",
                                   share_other_objective_head=seq_qa)
|
|
|
# Training schedule & arguments
from adaptor.utils import AdaptationArguments, StoppingStrategy

# Training stops only after ALL objectives have converged
# (StoppingStrategy.ALL_OBJECTIVES_CONVERGED), with a patience of 8 evaluations.
# Effective batch size per objective is 4 * 8 via gradient accumulation.
training_arguments = AdaptationArguments(output_dir="checkpoints-chatbot",
                                         learning_rate=5e-5,
                                         stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
                                         stopping_patience=8,
                                         save_total_limit=8,
                                         do_train=True,
                                         do_eval=True,
                                         bf16=True,
                                         warmup_steps=1000,
                                         gradient_accumulation_steps=8,
                                         logging_steps=10,
                                         eval_steps=200,
                                         save_steps=1000,
                                         num_train_epochs=10,
                                         evaluation_strategy="steps")
from adaptor.schedules import ParallelSchedule
from adaptor.adapter import Adapter

# ParallelSchedule interleaves batches from all three objectives during training.
schedule = ParallelSchedule(objectives=[seq_qa, seq_additional_qa, seq_rephrasing],
                            args=training_arguments)
adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()  # Training for 63k updates
|
``` |
|
|
|
## Usage |
|
See the prompting templates used in training to infer the optimal prompting format. |
|
|
|
#### Contact |
|
Feel free to ask questions here, or at stefanik{at} gaussalgo.com |