michal-stefanik commited on
Commit
bad0757
1 Parent(s): 8ef83e6

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +225 -0
README.md ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ datasets:
3
+ - gaussalgo/Canard_Wiki-augmented
4
+ - hotpot_qa
5
+ metrics:
6
+ - rouge
7
+ - bleu
8
+ model-index:
9
+ - name: T5-LM-Large_Canard-Fullwiki-HotpotQA-rephrase
10
+ results:
11
+ - task:
12
+ type: question-answering
13
+ name: Question Answering
14
+ dataset:
15
+ type: hotpot_qa
16
+ name: HotpotQA
17
+ split: validation
18
+ metrics:
19
+ - type: rouge
20
+ value: 0.4774
21
+ - type: bleu
22
+ value: 29.11
23
+ - task:
24
+ type: question-answering
25
+ name: Question Answering
26
+ dataset:
27
+ type: gaussalgo/Canard_Wiki-augmented
28
+ name: Wikipedia-augmented Conversational QA (Canard)
29
+ split: validation
30
+ metrics:
31
+ - type: rouge
32
+ value: 0.4377
33
+ - type: bleu
34
+ value: 19.34
35
+ license: cc-by-sa-4.0
36
+ language:
37
+ - en
38
+ ---
39
+
40
+ # Model Card for T5-LM-Large_Canard-Fullwiki-HotpotQA-rephrase
41
+ This model is trained on three objectives:
42
+ (1) Generating answers for Canard dataset based on Wikipedia search results
43
+ (2) Generating answers for HotpotQA,
44
+ (3) Rephrasing questions according to the conversation context.
45
+
46
+ ## Training
47
+ The model was trained using the following script, exported from the corresponding Jupyter notebook. All details, including the expected input format, can be inferred from the code below.
48
+ The best checkpoint was selected by the maximum ROUGE score on the Canard conversational QA validation split.
49
+ ```python
50
import datasets

# Wikipedia-augmented Canard: train for fitting, test for validation.
canard_train_augm = datasets.load_dataset("gaussalgo/Canard_Wiki-augmented", split="train")
canard_test_augm = datasets.load_dataset("gaussalgo/Canard_Wiki-augmented", split="test")

canard_df = canard_train_augm.to_pandas()
# BUG FIX: the original assigned canard_train_augm.to_pandas() here, so the
# "test" dataframe silently duplicated the training split and canard_test_augm
# was never used — validation would have run on training data.
canard_test_df = canard_test_augm.to_pandas()
57
+
58
+
59
+ ### Curation of seq2seq input contexts and labels
60
+ import random
61
+
62
def input_context_from_sample(row: dict, max_length=5) -> str:
    """Build the seq2seq input for conversational QA from one Canard row.

    The prompt contains the conversation history, the current question and a
    shuffled list of at most ``max_length`` search results — the true context
    is always among them — and ends with an answer prompt.
    """
    history = row["History"]
    parts = ["Previous conversation:", "\nQuestion: ", ", ".join(history[:3])]
    # after the first three entries, history alternates answer / question
    for idx in range(3, len(history), 2):
        parts += ["\nAnswer: ", history[idx]]
        if idx + 1 < len(history):
            parts += ["\nQuestion: ", history[idx + 1]]

    parts += ["\n\nCurrent Question: ", row["Question"], "\nSearch results:"]

    # keep at most max_length passages overall, always including the true one
    search_results = row["retrieved_contexts"].tolist()[:max_length - 1] + [row["true_contexts"]]
    random.shuffle(search_results)

    for pos, passage in enumerate(search_results, start=1):
        parts += ["\n[%s]: " % pos, passage.replace("CANNOTANSWER", "")]

    parts.append("\nCurrent Answer: ")
    return "".join(parts)
86
+
87
def rephrasing_context_from_sample(row: dict) -> str:
    """Build the seq2seq input for the question-rephrasing objective.

    Same conversation-history prompt as the QA objective, but it ends with a
    rephrasing prompt instead of search results and an answer prompt.
    """
    hist = row["History"]
    segments = ["Previous conversation:", "\nQuestion: ", ", ".join(hist[:3])]
    # after the first three entries, history alternates answer / question
    for j in range(3, len(hist), 2):
        segments += ["\nAnswer: ", hist[j]]
        if j + 1 < len(hist):
            segments += ["\nQuestion: ", hist[j + 1]]

    segments += ["\n\nCurrent Question: ", row["Question"], "\nMore specific question: "]
    return "".join(segments)
103
+
104
def hotpotqa_context(row: dict) -> str:
    """Build the seq2seq input for one HotpotQA sample (no conversation history)."""
    # one passage per paragraph: its sentences joined by a single space
    # (renamed from the original, which shadowed `context` inside the comprehension)
    passages = [" ".join(sentences) for sentences in row["context"]["sentences"]]

    pieces = ["Current Question: ", row["question"], "\nSearch results:"]
    for num, passage in enumerate(passages, start=1):
        pieces += ["\n[%s]: " % num, passage.replace("CANNOTANSWER", "")]

    pieces.append("\nCurrent Answer: ")
    return "".join(pieces)
117
+
118
+ # Conversational QA sequences
119
+ input_texts = canard_df.apply(lambda row: input_context_from_sample(row), axis=1).values
120
+ input_val_texts = canard_test_df.iloc[:200].apply(lambda row: input_context_from_sample(row), axis=1).values
121
+
122
+ too_long_index = [len(t) > 20000 for t in input_texts]
123
+ input_texts = [t for i, t in enumerate(input_texts) if not too_long_index[i]]
124
+ # print(too_long_index)
125
+ print("training on %s samples" % len(input_texts))
126
+
127
+ labels = canard_df.answer.apply(lambda ans: "No answer" if ans == "CANNOTANSWER" else ans).values
128
+ labels = [l for i, l in enumerate(labels) if not too_long_index[i]]
129
+ val_labels = canard_test_df.answer.apply(lambda ans: "No answer" if ans == "CANNOTANSWER" else ans).values
130
+
131
+ # Rephrasing sequences
132
+ rephrasing_inputs = canard_df.apply(lambda row: rephrasing_context_from_sample(row), axis=1).values
133
+ rephrasing_val_inputs = canard_test_df.apply(lambda row: rephrasing_context_from_sample(row), axis=1).values
134
+
135
+ rephrasing_labels = canard_df.Rewrite.values
136
+ rephrasing_val_labels = canard_test_df.Rewrite.values
137
+
138
+ # HotpotQA sequences
139
+ hotpot_train = datasets.load_dataset("hotpot_qa", "distractor")["train"]
140
+ hotpot_val = datasets.load_dataset("hotpot_qa", "distractor")["validation"]
141
+
142
+ hotpot_inputs = hotpot_train.to_pandas().apply(hotpotqa_context, axis=1)
143
+ hotpot_val_inputs = hotpot_val.to_pandas().apply(hotpotqa_context, axis=1)
144
+ too_long_index = [len(t) > 20000 for t in hotpot_inputs]
145
+
146
+ hotpot_inputs = [t for i, t in enumerate(hotpot_inputs) if not too_long_index[i]]
147
+ hotpot_answers = [t for i, t in enumerate(hotpot_train["answer"]) if not too_long_index[i]]
148
+
149
+ # Training routine
150
+ # see Adaptor's homepage for details:
151
+ # https://github.com/gaussalgo/adaptor
152
+
153
+ # Base model
154
+ from adaptor.lang_module import LangModule
155
+ lang_module = LangModule("google/t5-large-lm-adapt")
156
+
157
+ from adaptor.evaluators.generative import ROUGE, BLEU
158
+
159
+ # Evaluations
160
+ evaluators = [BLEU(), ROUGE(decides_convergence=True)]
161
+
162
+ # Objectives
163
+ from adaptor.objectives.seq2seq import Sequence2Sequence
164
+
165
+ seq_qa = Sequence2Sequence(lang_module,
166
+ texts_or_path=input_texts,
167
+ labels_or_path=labels,
168
+ val_texts_or_path=input_val_texts,
169
+ val_labels_or_path=val_labels,
170
+ batch_size=4,
171
+ val_evaluators=evaluators,
172
+ objective_id="Canard")
173
+
174
+ seq_additional_qa = Sequence2Sequence(lang_module,
175
+ texts_or_path=hotpot_inputs,
176
+ labels_or_path=hotpot_answers,
177
+ val_texts_or_path=hotpot_val_inputs[:200],
178
+ val_labels_or_path=hotpot_val["answer"][:200],
179
+ batch_size=4,
180
+ val_evaluators=evaluators,
181
+ objective_id="HotpotQA",
182
+ share_other_objective_head=seq_qa)
183
+
184
+ seq_rephrasing = Sequence2Sequence(lang_module,
185
+ texts_or_path=rephrasing_inputs,
186
+ labels_or_path=rephrasing_labels,
187
+ val_texts_or_path=rephrasing_val_inputs[:200],
188
+ val_labels_or_path=rephrasing_val_labels[:200],
189
+ batch_size=4,
190
+ val_evaluators=evaluators,
191
+ objective_id="rephrasing",
192
+ share_other_objective_head=seq_qa)
193
+
194
+ # Training schedule & arguments
195
+ from adaptor.utils import AdaptationArguments, StoppingStrategy
196
+
197
+ training_arguments = AdaptationArguments(output_dir="checkpoints-chatbot",
198
+ learning_rate=5e-5,
199
+ stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
200
+ stopping_patience=8,
201
+ save_total_limit=8,
202
+ do_train=True,
203
+ do_eval=True,
204
+ bf16=True,
205
+ warmup_steps=1000,
206
+ gradient_accumulation_steps=8,
207
+ logging_steps=10,
208
+ eval_steps=200,
209
+ save_steps=1000,
210
+ num_train_epochs=10,
211
+ evaluation_strategy="steps")
212
+ from adaptor.schedules import ParallelSchedule
213
+ from adaptor.adapter import Adapter
214
+
215
+ schedule = ParallelSchedule(objectives=[seq_qa, seq_additional_qa, seq_rephrasing],
216
+ args=training_arguments)
217
+ adapter = Adapter(lang_module, schedule, args=training_arguments)
218
+ adapter.train() # Training for 63k updates
219
+ ```
220
+
221
+ ## Usage
222
+ See the prompting templates used in training to infer the optimal prompting format.
223
+
224
+ #### Contact
225
+ Feel free to ask questions here, or at stefanik{at} gaussalgo.com