---
datasets:
- gaussalgo/Canard_Wiki-augmented
- hotpot_qa
metrics:
- rouge
- bleu
model-index:
- name: T5-LM-Large_Canard-Fullwiki-HotpotQA-rephrase
results:
- task:
type: question-answering
name: Question Answering
dataset:
type: hotpot_qa
name: HotpotQA
split: validation
metrics:
- type: rouge
value: 0.4774
- type: bleu
value: 29.11
- task:
type: question-answering
name: Question Answering
dataset:
type: gaussalgo/Canard_Wiki-augmented
name: Wikipedia-augmented Conversational QA (Canard)
split: validation
metrics:
- type: rouge
value: 0.4377
- type: bleu
value: 19.34
license: cc-by-sa-4.0
language:
- en
---
# Model Card for T5-LM-Large_Canard-HotpotQA-rephrase
This model is trained on three objectives (the corresponding input formats are sketched after this list):
1. Generating answers for the Canard dataset, using retrieved Wikipedia search results as context,
2. Generating answers for HotpotQA,
3. Rephrasing questions into a standalone form based on the preceding conversation context.
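
A rough sketch of the input format, reconstructed from the context-construction functions in the training script below (the angle-bracket placeholders are illustrative, not literal tokens):

```python
# Illustrative only: approximate layout of the conversational-QA input,
# as built by input_context_from_sample() in the training script.
conversational_qa_input = """Previous conversation:
Question: <opening turns of the conversation history>
Answer: <following answer turn>
Question: <following question turn>

Current Question: <current question>
Search results:
[1]: <search result passage>
[2]: <search result passage>
Current Answer: """

# The rephrasing objective reuses the same conversation header but ends with
# "More specific question: " instead of the search results and answer prompt;
# the HotpotQA objective omits the history and starts at "Current Question: ".
```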
## Training
The model was trained using the script below, which can be copy-pasted and run as-is (with the corresponding `requirements.txt` installed).
All details, including the input format, can be inferred from the code.
The best checkpoint was selected by the maximum ROUGE on the Canard conversational QA validation split.
```python
import datasets
canard_train_augm = datasets.load_dataset("gaussalgo/Canard_Wiki-augmented", split="train")
canard_test_augm = datasets.load_dataset("gaussalgo/Canard_Wiki-augmented", split="test")

canard_df = canard_train_augm.to_pandas()
canard_test_df = canard_test_augm.to_pandas()
### Curation of seq2seq input contexts and labels
import random
def input_context_from_sample(row: dict, max_length=5) -> str:
    # Builds the conversational-QA input: conversation history, current question,
    # a shuffled mix of retrieved and true search results, and an answer prompt.
    context = "Previous conversation:"
    context += "\nQuestion: "
    context += ", ".join(row["History"][:3])
    for i in range(3, len(row["History"]), 2):
        context += "\nAnswer: "
        context += row["History"][i]
        if i+1 < len(row["History"]):
            context += "\nQuestion: "
            context += row["History"][i+1]
    context += "\n\nCurrent Question: "
    context += row["Question"]
    context += "\nSearch results:"
    all_contexts = row["retrieved_contexts"].tolist()[:max_length-1] + [row["true_contexts"]]
    random.shuffle(all_contexts)
    for i, search_result in enumerate(all_contexts):
        context += "\n[%s]: " % (i+1)
        context += search_result.replace("CANNOTANSWER", "")
    context += "\nCurrent Answer: "
    return context
def rephrasing_context_from_sample(row: dict) -> str:
    # Builds the question-rephrasing input: same conversation header as above,
    # but prompting for a standalone reformulation instead of an answer.
    context = "Previous conversation:"
    context += "\nQuestion: "
    context += ", ".join(row["History"][:3])
    for i in range(3, len(row["History"]), 2):
        context += "\nAnswer: "
        context += row["History"][i]
        if i+1 < len(row["History"]):
            context += "\nQuestion: "
            context += row["History"][i+1]
    context += "\n\nCurrent Question: "
    context += row["Question"]
    context += "\nMore specific question: "
    return context
def hotpotqa_context(row: dict) -> str:
    # Builds the HotpotQA input: no conversation history, just the question
    # followed by the distractor-setting paragraphs as numbered search results.
    context = "Current Question: "
    context += row["question"]
    context += "\nSearch results:"
    all_contexts = [" ".join(sentences) for sentences in row["context"]["sentences"]]
    for i, search_result in enumerate(all_contexts):
        context += "\n[%s]: " % (i+1)
        context += search_result.replace("CANNOTANSWER", "")
    context += "\nCurrent Answer: "
    return context
# Conversational QA sequences
input_texts = canard_df.apply(lambda row: input_context_from_sample(row), axis=1).values
input_val_texts = canard_test_df.iloc[:200].apply(lambda row: input_context_from_sample(row), axis=1).values
too_long_index = [len(t) > 20000 for t in input_texts]
input_texts = [t for i, t in enumerate(input_texts) if not too_long_index[i]]
print("training on %s samples" % len(input_texts))

labels = canard_df.answer.apply(lambda ans: "No answer" if ans == "CANNOTANSWER" else ans).values
labels = [l for i, l in enumerate(labels) if not too_long_index[i]]
# Validation labels are restricted to the same first 200 samples as input_val_texts
val_labels = canard_test_df.iloc[:200].answer.apply(lambda ans: "No answer" if ans == "CANNOTANSWER" else ans).values
# Rephrasing sequences
rephrasing_inputs = canard_df.apply(lambda row: rephrasing_context_from_sample(row), axis=1).values
rephrasing_val_inputs = canard_test_df.apply(lambda row: rephrasing_context_from_sample(row), axis=1).values
rephrasing_labels = canard_df.Rewrite.values
rephrasing_val_labels = canard_test_df.Rewrite.values
# HotpotQA sequences
hotpot_train = datasets.load_dataset("hotpot_qa", "distractor")["train"]
hotpot_val = datasets.load_dataset("hotpot_qa", "distractor")["validation"]
hotpot_inputs = hotpot_train.to_pandas().apply(hotpotqa_context, axis=1)
hotpot_val_inputs = hotpot_val.to_pandas().apply(hotpotqa_context, axis=1)
too_long_index = [len(t) > 20000 for t in hotpot_inputs]
hotpot_inputs = [t for i, t in enumerate(hotpot_inputs) if not too_long_index[i]]
hotpot_answers = [t for i, t in enumerate(hotpot_train["answer"]) if not too_long_index[i]]
# Training routine
# see Adaptor's homepage for details:
# https://github.com/gaussalgo/adaptor
# Base model
from adaptor.lang_module import LangModule
lang_module = LangModule("google/t5-large-lm-adapt")
from adaptor.evaluators.generative import ROUGE, BLEU
# Evaluations
evaluators = [BLEU(), ROUGE(decides_convergence=True)]
# Objectives
from adaptor.objectives.seq2seq import Sequence2Sequence
seq_qa = Sequence2Sequence(lang_module,
                           texts_or_path=input_texts,
                           labels_or_path=labels,
                           val_texts_or_path=input_val_texts,
                           val_labels_or_path=val_labels,
                           batch_size=4,
                           val_evaluators=evaluators,
                           objective_id="Canard")
seq_additional_qa = Sequence2Sequence(lang_module,
                                      texts_or_path=hotpot_inputs,
                                      labels_or_path=hotpot_answers,
                                      val_texts_or_path=hotpot_val_inputs[:200],
                                      val_labels_or_path=hotpot_val["answer"][:200],
                                      batch_size=4,
                                      val_evaluators=evaluators,
                                      objective_id="HotpotQA",
                                      share_other_objective_head=seq_qa)
seq_rephrasing = Sequence2Sequence(lang_module,
                                   texts_or_path=rephrasing_inputs,
                                   labels_or_path=rephrasing_labels,
                                   val_texts_or_path=rephrasing_val_inputs[:200],
                                   val_labels_or_path=rephrasing_val_labels[:200],
                                   batch_size=4,
                                   val_evaluators=evaluators,
                                   objective_id="rephrasing",
                                   share_other_objective_head=seq_qa)
# Training schedule & arguments
from adaptor.utils import AdaptationArguments, StoppingStrategy
training_arguments = AdaptationArguments(output_dir="checkpoints-chatbot",
                                         learning_rate=5e-5,
                                         stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
                                         stopping_patience=8,
                                         save_total_limit=8,
                                         do_train=True,
                                         do_eval=True,
                                         bf16=True,
                                         warmup_steps=1000,
                                         gradient_accumulation_steps=8,
                                         logging_steps=10,
                                         eval_steps=200,
                                         save_steps=1000,
                                         num_train_epochs=10,
                                         evaluation_strategy="steps")
from adaptor.schedules import ParallelSchedule
from adaptor.adapter import Adapter
schedule = ParallelSchedule(objectives=[seq_qa, seq_additional_qa, seq_rephrasing],
                            args=training_arguments)
adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train() # Training for 63k updates
```
## Usage
To get the best results, format inputs according to the prompting templates used in training (see the context-construction functions in the training script above).
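
A minimal inference sketch, assuming the `transformers` library and this repository as the model id (the `gaussalgo` namespace and exact id are inferred from the card, so adjust `model_id` if you load the weights from elsewhere); the prompt follows the conversational-QA template used in training:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Assumption: the model is hosted under this id (inferred from the card title and
# the gaussalgo dataset namespace); adjust it if you load the weights from elsewhere.
model_id = "gaussalgo/T5-LM-Large_Canard-HotpotQA-rephrase"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

# An input following the conversational-QA template used in training
# (conversation history, current question, numbered search results, answer prompt).
prompt = ("Previous conversation:"
          "\nQuestion: What is T5?"
          "\nAnswer: T5 is an encoder-decoder language model by Google."
          "\nQuestion: Is it open-source?"
          "\n\nCurrent Question: When was T5 released?"
          "\nSearch results:"
          "\n[1]: T5 was introduced in 2019 in the paper 'Exploring the Limits of "
          "Transfer Learning with a Unified Text-to-Text Transformer'."
          "\nCurrent Answer: ")

inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```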
#### Contact
Feel free to ask questions here, or at stefanik{at}gaussalgo.com