---
datasets:
- gaussalgo/Canard_Wiki-augmented
- hotpot_qa
metrics:
- rouge
- bleu
model-index:
- name: T5-LM-Large_Canard-Fullwiki-HotpotQA-rephrase
  results:
  - task:
      type: question-answering
      name: Question Answering
    dataset:
      type: hotpot_qa
      name: HotpotQA
      split: validation
    metrics:
    - type: rouge
      value: 0.4774
    - type: bleu
      value: 29.11
  - task:
      type: question-answering
      name: Question Answering
    dataset:
      type: gaussalgo/Canard_Wiki-augmented
      name: Wikipedia-augmented Conversational QA (Canard)
      split: validation
    metrics:
    - type: rouge
      value: 0.4377
    - type: bleu
      value: 19.34
license: cc-by-sa-4.0
language:
- en
---

# Model Card for T5-LM-Large_Canard-HotpotQA-rephrase 
This model is trained on three objectives:
  1. Generating answers for the Canard dataset based on Wikipedia search results,
  2. Generating answers for HotpotQA,
  3. Rephrasing questions using the conversation context.

## Training
The model was trained using the following script, which can be copy-pasted and run as-is (with the dependencies from `requirements.txt` installed).
All details, including the input (prompt) format, can be read directly from the code.
The best checkpoint was selected by the maximum validation ROUGE on the Canard conversational QA objective.

```python
import datasets

canard_train_augm = datasets.load_dataset("gaussalgo/Canard_Wiki-augmented", split="train")
canard_test_augm = datasets.load_dataset("gaussalgo/Canard_Wiki-augmented", split="test")

canard_df = canard_train_augm.to_pandas()
canard_test_df = canard_test_augm.to_pandas()  # use the test split, not the train split


### Curation of seq2seq input contexts and labels
import random

def input_context_from_sample(row: dict, max_length=5) -> str:
    context = "Previous conversation:"
    context += "\nQuestion: "
    context += ", ".join(row["History"][:3])
    for i in range(3, len(row["History"]), 2):
        context += "\nAnswer: "
        context += row["History"][i]
        if i+1 < len(row["History"]):
            context += "\nQuestion: "
            context += row["History"][i+1]

    context += "\n\nCurrent Question: "
    context += row["Question"]

    context += "\nSearch results:"
    all_contexts = row["retrieved_contexts"].tolist()[:max_length-1] + [row["true_contexts"]]
    random.shuffle(all_contexts)

    for i, search_result in enumerate(all_contexts):
        context += "\n[%s]: " % (i+1)
        context += search_result.replace("CANNOTANSWER", "")

    context += "\nCurrent Answer: "
    return context

def rephrasing_context_from_sample(row: dict) -> str:
    context = "Previous conversation:"
    context += "\nQuestion: "
    context += ", ".join(row["History"][:3])
    for i in range(3, len(row["History"]), 2):
        context += "\nAnswer: "
        context += row["History"][i]
        if i+1 < len(row["History"]):
            context += "\nQuestion: "
            context += row["History"][i+1]
    
    context += "\n\nCurrent Question: "
    context += row["Question"]

    context += "\nMore specific question: "
    return context

def hotpotqa_context(row: dict) -> str:
    context = "Current Question: "
    context += row["question"]

    context += "\nSearch results:"
    # each paragraph is a list of sentences; join them into one search result
    all_contexts = [" ".join(sentences) for sentences in row["context"]["sentences"]]

    for i, search_result in enumerate(all_contexts):
        context += "\n[%s]: " % (i+1)
        context += search_result.replace("CANNOTANSWER", "")

    context += "\nCurrent Answer: "
    return context

# Conversational QA sequences
input_texts = canard_df.apply(lambda row: input_context_from_sample(row), axis=1).values
input_val_texts = canard_test_df.iloc[:200].apply(lambda row: input_context_from_sample(row), axis=1).values

too_long_index = [len(t) > 20000 for t in input_texts]
input_texts = [t for i, t in enumerate(input_texts) if not too_long_index[i]]
# print(too_long_index)
print("training on %s samples" % len(input_texts))

labels = canard_df.answer.apply(lambda ans: "No answer" if ans == "CANNOTANSWER" else ans).values
labels = [l for i, l in enumerate(labels)  if not too_long_index[i]]
# keep the validation labels aligned with the first 200 validation inputs
val_labels = canard_test_df.answer.iloc[:200].apply(lambda ans: "No answer" if ans == "CANNOTANSWER" else ans).values

# Rephrasing sequences
rephrasing_inputs = canard_df.apply(lambda row: rephrasing_context_from_sample(row), axis=1).values
rephrasing_val_inputs = canard_test_df.apply(lambda row: rephrasing_context_from_sample(row), axis=1).values

rephrasing_labels = canard_df.Rewrite.values
rephrasing_val_labels = canard_test_df.Rewrite.values

# HotpotQA sequences
hotpot_train = datasets.load_dataset("hotpot_qa", "distractor")["train"]
hotpot_val = datasets.load_dataset("hotpot_qa", "distractor")["validation"]

hotpot_inputs = hotpot_train.to_pandas().apply(hotpotqa_context, axis=1)
hotpot_val_inputs = hotpot_val.to_pandas().apply(hotpotqa_context, axis=1)
too_long_index = [len(t) > 20000 for t in hotpot_inputs]

hotpot_inputs = [t for i, t in enumerate(hotpot_inputs) if not too_long_index[i]]
hotpot_answers = [t for i, t in enumerate(hotpot_train["answer"]) if not too_long_index[i]]

# Training routine
# see Adaptor's homepage for details:
# https://github.com/gaussalgo/adaptor

# Base model
from adaptor.lang_module import LangModule
lang_module = LangModule("google/t5-large-lm-adapt")

from adaptor.evaluators.generative import ROUGE, BLEU

# Evaluations
evaluators = [BLEU(), ROUGE(decides_convergence=True)]

# Objectives
from adaptor.objectives.seq2seq import Sequence2Sequence

seq_qa = Sequence2Sequence(lang_module,
                           texts_or_path=input_texts,
                           labels_or_path=labels,
                           val_texts_or_path=input_val_texts,
                           val_labels_or_path=val_labels,
                           batch_size=4,
                           val_evaluators=evaluators,
                           objective_id="Canard")

seq_additional_qa = Sequence2Sequence(lang_module,
                                      texts_or_path=hotpot_inputs,
                                      labels_or_path=hotpot_answers,
                                      val_texts_or_path=hotpot_val_inputs[:200],
                                      val_labels_or_path=hotpot_val["answer"][:200],
                                      batch_size=4,
                                      val_evaluators=evaluators,
                                      objective_id="HotpotQA",
                                      share_other_objective_head=seq_qa)

seq_rephrasing = Sequence2Sequence(lang_module,
                                   texts_or_path=rephrasing_inputs,
                                   labels_or_path=rephrasing_labels,
                                   val_texts_or_path=rephrasing_val_inputs[:200],
                                   val_labels_or_path=rephrasing_val_labels[:200],
                                   batch_size=4,
                                   val_evaluators=evaluators,
                                   objective_id="rephrasing",
                                   share_other_objective_head=seq_qa)

# Training schedule & arguments
from adaptor.utils import AdaptationArguments, StoppingStrategy

training_arguments = AdaptationArguments(output_dir="checkpoints-chatbot",
                                         learning_rate=5e-5,
                                         stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
                                         stopping_patience=8,
                                         save_total_limit=8,
                                         do_train=True,
                                         do_eval=True,
                                         bf16=True,
                                         warmup_steps=1000,
                                         gradient_accumulation_steps=8,
                                         logging_steps=10,
                                         eval_steps=200,
                                         save_steps=1000,
                                         num_train_epochs=10,
                                         evaluation_strategy="steps")
from adaptor.schedules import ParallelSchedule
from adaptor.adapter import Adapter

schedule = ParallelSchedule(objectives=[seq_qa, seq_additional_qa, seq_rephrasing],
                            args=training_arguments)
adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()  # Training for 63k updates
```

## Usage
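The model expects inputs in the same format it saw in training. See the prompt-building functions in the training script (`input_context_from_sample`, `rephrasing_context_from_sample` and `hotpotqa_context`) for the exact templates.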
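The conversational templates from the training script can be rebuilt at inference time without pandas. The following is a minimal sketch, not the author's inference code: the helper names (`_conversation_prefix`, `build_qa_prompt`, `build_rephrasing_prompt`) and the example conversation are hypothetical, and `history` is assumed to follow the layout of the `History` column used in training (the first three items are joined into the opening question, then answers and questions alternate).

```python
from typing import Sequence


def _conversation_prefix(history: Sequence[str], question: str) -> str:
    """Prefix shared by the QA and rephrasing templates of the training script."""
    prompt = "Previous conversation:"
    prompt += "\nQuestion: " + ", ".join(history[:3])
    # From index 3 onwards the history alternates: answer, question, answer, ...
    for i in range(3, len(history), 2):
        prompt += "\nAnswer: " + history[i]
        if i + 1 < len(history):
            prompt += "\nQuestion: " + history[i + 1]
    prompt += "\n\nCurrent Question: " + question
    return prompt


def build_qa_prompt(history: Sequence[str], question: str,
                    search_results: Sequence[str]) -> str:
    """Conversational QA prompt: prefix, numbered search results, answer cue."""
    prompt = _conversation_prefix(history, question)
    prompt += "\nSearch results:"
    for i, result in enumerate(search_results):
        prompt += "\n[%s]: %s" % (i + 1, result)
    return prompt + "\nCurrent Answer: "


def build_rephrasing_prompt(history: Sequence[str], question: str) -> str:
    """Rephrasing prompt: prefix followed by the rephrasing cue."""
    return _conversation_prefix(history, question) + "\nMore specific question: "


# Illustrative inputs only (hypothetical conversation and search result).
history = ["Albert Einstein", "Early life", "Where was he born?", "In Ulm."]
qa_prompt = build_qa_prompt(history, "When did he leave?",
                            ["The family moved to Munich in 1880."])
print(qa_prompt)
```

The resulting string can then be tokenized and passed to the model's `generate` method like any other seq2seq input. Unlike training, no gold context is mixed into the search results here, so answer quality will depend on what the retriever returns.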

#### Contact
Feel free to ask questions here, or at stefanik{at} gaussalgo.com