---
datasets:
- gaussalgo/Canard_Wiki-augmented
- hotpot_qa
metrics:
- rouge
- bleu
model-index:
- name: T5-LM-Large_Canard-Fullwiki-HotpotQA-rephrase
  results:
  - task:
      type: question-answering
      name: Question Answering
    dataset:
      type: hotpot_qa
      name: HotpotQA
      split: validation
    metrics:
    - type: rouge
      value: 0.4774
    - type: bleu
      value: 29.11
  - task:
      type: question-answering
      name: Question Answering
    dataset:
      type: gaussalgo/Canard_Wiki-augmented
      name: Wikipedia-augmented Conversational QA (Canard)
      split: validation
    metrics:
    - type: rouge
      value: 0.4377
    - type: bleu
      value: 19.34
license: cc-by-sa-4.0
language:
- en
---

# Model Card for T5-LM-Large_Canard-HotpotQA-rephrase

This model is trained on three objectives:

1. Generating answers for the Canard dataset, based on Wikipedia search results,
2. Generating answers for HotpotQA,
3. Rephrasing questions according to the conversation context.

## Training

The model was trained using the following script, which can be copy-pasted and run as-is (with the dependencies from `requirements.txt` installed). All details, including the request format, can be inferred from the code.

The best checkpoint was picked by the maximum ROUGE on the Canard conversational QA validation split.

```python
import random

import datasets

canard_train_augm = datasets.load_dataset("gaussalgo/Canard_Wiki-augmented", split="train")
canard_test_augm = datasets.load_dataset("gaussalgo/Canard_Wiki-augmented", split="test")

canard_df = canard_train_augm.to_pandas()
canard_test_df = canard_test_augm.to_pandas()


### Curation of seq2seq input contexts and labels

def input_context_from_sample(row: dict, max_length=5) -> str:
    # The first three history entries are joined into the opening "Question:" line,
    # then the remaining history alternates between answers and follow-up questions
    context = "Previous conversation:"
    context += "\nQuestion: "
    context += ", ".join(row["History"][:3])
    for i in range(3, len(row["History"]), 2):
        context += "\nAnswer: "
        context += row["History"][i]
        if i + 1 < len(row["History"]):
            context += "\nQuestion: "
            context += row["History"][i + 1]
    context += "\n\nCurrent Question: "
    context += row["Question"]
    context += "\nSearch results:"
    # Mix the true context among the retrieved ones, in random order
    all_contexts = row["retrieved_contexts"].tolist()[:max_length - 1] + [row["true_contexts"]]
    random.shuffle(all_contexts)
    for i, search_result in enumerate(all_contexts):
        context += "\n[%s]: " % (i + 1)
        context += search_result.replace("CANNOTANSWER", "")
    context += "\nCurrent Answer: "
    return context


def rephrasing_context_from_sample(row: dict) -> str:
    context = "Previous conversation:"
    context += "\nQuestion: "
    context += ", ".join(row["History"][:3])
    for i in range(3, len(row["History"]), 2):
        context += "\nAnswer: "
        context += row["History"][i]
        if i + 1 < len(row["History"]):
            context += "\nQuestion: "
            context += row["History"][i + 1]
    context += "\n\nCurrent Question: "
    context += row["Question"]
    context += "\nMore specific question: "
    return context


def hotpotqa_context(row: dict) -> str:
    context = "Current Question: "
    context += row["question"]
    context += "\nSearch results:"
    all_contexts = [" ".join(sentences) for sentences in row["context"]["sentences"]]
    for i, search_result in enumerate(all_contexts):
        context += "\n[%s]: " % (i + 1)
        context += search_result.replace("CANNOTANSWER", "")
    context += "\nCurrent Answer: "
    return context


# Conversational QA sequences
input_texts = canard_df.apply(lambda row: input_context_from_sample(row), axis=1).values
input_val_texts = canard_test_df.iloc[:200].apply(lambda row: input_context_from_sample(row), axis=1).values

# Drop samples whose curated context is too long for training
too_long_index = [len(t) > 20000 for t in input_texts]
input_texts = [t for i, t in enumerate(input_texts) if not too_long_index[i]]

print("training on %s samples" % len(input_texts))
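# For illustration (added to this card, not part of the original script):
# input_context_from_sample() curates requests of roughly this shape, with the
# true context shuffled in among the retrieved search results:
#
#   Previous conversation:
#   Question: <first three history entries, comma-separated>
#   Answer: <previous answer>
#   Question: <follow-up question>
#
#   Current Question: <current question>
#   Search results:
#   [1]: <search result passage>
#   [2]: <search result passage>
#   Current Answer: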
samples" % len(input_texts)) labels = canard_df.answer.apply(lambda ans: "No answer" if ans == "CANNOTANSWER" else ans).values labels = [l for i, l in enumerate(labels) if not too_long_index[i]] val_labels = canard_test_df.answer.apply(lambda ans: "No answer" if ans == "CANNOTANSWER" else ans).values # Rephrasing sequences rephrasing_inputs = canard_df.apply(lambda row: rephrasing_context_from_sample(row), axis=1).values rephrasing_val_inputs = canard_test_df.apply(lambda row: rephrasing_context_from_sample(row), axis=1).values rephrasing_labels = canard_df.Rewrite.values rephrasing_val_labels = canard_test_df.Rewrite.values # HotpotQA sequences hotpot_train = datasets.load_dataset("hotpot_qa", "distractor")["train"] hotpot_val = datasets.load_dataset("hotpot_qa", "distractor")["validation"] hotpot_inputs = hotpot_train.to_pandas().apply(hotpotqa_context, axis=1) hotpot_val_inputs = hotpot_val.to_pandas().apply(hotpotqa_context, axis=1) too_long_index = [len(t) > 20000 for t in hotpot_inputs] hotpot_inputs = [t for i, t in enumerate(hotpot_inputs) if not too_long_index[i]] hotpot_answers = [t for i, t in enumerate(hotpot_train["answer"]) if not too_long_index[i]] # Training routine # see Adaptor's homepage for details: # https://github.com/gaussalgo/adaptor # Base model from adaptor.lang_module import LangModule lang_module = LangModule("google/t5-large-lm-adapt") from adaptor.evaluators.generative import ROUGE, BLEU # Evaluations evaluators = [BLEU(), ROUGE(decides_convergence=True)] # Objectives from adaptor.objectives.seq2seq import Sequence2Sequence seq_qa = Sequence2Sequence(lang_module, texts_or_path=input_texts, labels_or_path=labels, val_texts_or_path=input_val_texts, val_labels_or_path=val_labels, batch_size=4, val_evaluators=evaluators, objective_id="Canard") seq_additional_qa = Sequence2Sequence(lang_module, texts_or_path=hotpot_inputs, labels_or_path=hotpot_answers, val_texts_or_path=hotpot_val_inputs[:200], val_labels_or_path=hotpot_val["answer"][:200], batch_size=4, val_evaluators=evaluators, objective_id="HotpotQA", share_other_objective_head=seq_qa) seq_rephrasing = Sequence2Sequence(lang_module, texts_or_path=rephrasing_inputs, labels_or_path=rephrasing_labels, val_texts_or_path=rephrasing_val_inputs[:200], val_labels_or_path=rephrasing_val_labels[:200], batch_size=4, val_evaluators=evaluators, objective_id="rephrasing", share_other_objective_head=seq_qa) # Training schedule & arguments from adaptor.utils import AdaptationArguments, StoppingStrategy training_arguments = AdaptationArguments(output_dir="checkpoints-chatbot", learning_rate=5e-5, stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED, stopping_patience=8, save_total_limit=8, do_train=True, do_eval=True, bf16=True, warmup_steps=1000, gradient_accumulation_steps=8, logging_steps=10, eval_steps=200, save_steps=1000, num_train_epochs=10, evaluation_strategy="steps") from adaptor.schedules import ParallelSchedule from adaptor.adapter import Adapter schedule = ParallelSchedule(objectives=[seq_qa, seq_additional_qa, seq_rephrasing], args=training_arguments) adapter = Adapter(lang_module, schedule, args=training_arguments) adapter.train() # Training for 63k updates ``` ## Usage See the prompting templates used in training to infer the optimal prompting format. #### Contact Feel free to ask questions here, or at stefanik{at} gaussalgo.com