---
datasets:
- hotpot_qa
- gaussalgo/Canard_Wiki-augmented
---

# Model Card for T5-LM-Large_Canard-HotpotQA-rephrase

This model is trained on three objectives: (1) generating answers for the Canard dataset, (2) generating answers for HotpotQA, and (3) rephrasing Canard questions into more specific ones based on the previous conversation.

## Training

The model was trained with the script below, exported from the corresponding Jupyter notebook. All details, including the input (prompt) format, can be read directly from the code. The best checkpoint was selected by the lowest loss over all three training objectives.

A few fixes have since been applied to the exported script. Each one is marked with a `NOTE` comment that describes the behavior of the original export.

```python
import random

import datasets

# see the dataset card for details
canard_train_augm = datasets.load_dataset("gaussalgo/Canard_Wiki-augmented", split="train")
canard_test_augm = datasets.load_dataset("gaussalgo/Canard_Wiki-augmented", split="test")

canard_df = canard_train_augm.to_pandas()
# NOTE: the original export built this frame from `canard_train_augm`,
# so the Canard and rephrasing validation sets were drawn from training data.
canard_test_df = canard_test_augm.to_pandas()

# Inputs longer than this many characters are dropped from the training sets.
MAX_INPUT_CHARS = 20000
# Number of samples used for validation in each objective.
N_VAL_SAMPLES = 200


### Curation of seq2seq input contexts and labels

def _conversation_history(row: dict) -> str:
    """Render the conversation history of a Canard sample as a prompt prefix.

    The first three `History` entries are joined into the opening question line;
    the remaining entries are rendered as alternating answers and questions.
    """
    history = row["History"]
    context = "Previous conversation:"
    context += "\nQuestion: "
    context += ", ".join(history[:3])
    for i in range(3, len(history), 2):
        context += "\nAnswer: "
        context += history[i]
        if i + 1 < len(history):
            context += "\nQuestion: "
            context += history[i + 1]
    return context


def input_context_from_sample(row: dict, max_length=5) -> str:
    """Build the Canard answer-generation prompt for one sample.

    :param row: a row of the Canard_Wiki-augmented frame; reads `History`, `Question`,
        `retrieved_contexts` and `true_contexts`.
    :param max_length: number of search results in the prompt: up to `max_length - 1`
        retrieved contexts plus the true context.
    :return: the prompt, ending with "Current Answer: " for the model to complete.
    """
    context = _conversation_history(row)
    context += "\n\nCurrent Question: "
    context += row["Question"]
    context += "\nSearch results:"
    all_contexts = row["retrieved_contexts"].tolist()[:max_length - 1] + [row["true_contexts"]]
    # The true context is appended last; shuffle so its position is not a giveaway.
    random.shuffle(all_contexts)
    for i, search_result in enumerate(all_contexts):
        context += "\n[%s]: " % (i + 1)
        context += search_result.replace("CANNOTANSWER", "")
    context += "\nCurrent Answer: "
    return context


def rephrasing_context_from_sample(row: dict) -> str:
    """Build the prompt asking for a more specific version of the current Canard question."""
    context = _conversation_history(row)
    context += "\n\nCurrent Question: "
    context += row["Question"]
    context += "\nMore specific question: "
    return context


def hotpotqa_context(row: dict) -> str:
    """Build the HotpotQA answer-generation prompt: the question followed by its context passages."""
    context = "Current Question: "
    context += row["question"]
    context += "\nSearch results:"
    # Join the sentences of each context entry into a single passage string.
    passages = [" ".join(sentences) for sentences in row["context"]["sentences"]]
    for i, passage in enumerate(passages):
        context += "\n[%s]: " % (i + 1)
        # NOTE: this line was commented out in the original export, so the HotpotQA
        # prompts contained only the "[i]: " markers without any passage text.
        context += passage
    context += "\nCurrent Answer: "
    return context


def _answer_label(answer: str) -> str:
    """Map Canard's "CANNOTANSWER" marker to the target text "No answer"."""
    return "No answer" if answer == "CANNOTANSWER" else answer


# Canard question answering: filter inputs and labels together so they stay aligned.
train_pairs = [
    (text, label)
    for text, label in zip(canard_df.apply(input_context_from_sample, axis=1).values,
                           canard_df.answer.apply(_answer_label).values)
    if len(text) <= MAX_INPUT_CHARS
]
input_texts = [text for text, _ in train_pairs]
labels = [label for _, label in train_pairs]
print("training on %s samples" % len(input_texts))

canard_val_df = canard_test_df.iloc[:N_VAL_SAMPLES]
input_val_texts = canard_val_df.apply(input_context_from_sample, axis=1).values.tolist()
# NOTE: the original export computed `val_labels` over the whole test frame,
# which did not line up with the 200 validation inputs.
val_labels = canard_val_df.answer.apply(_answer_label).values.tolist()

# Question rephrasing.
rephrasing_inputs = canard_df.apply(rephrasing_context_from_sample, axis=1).values.tolist()
print(rephrasing_inputs[0])
rephrasing_val_inputs = canard_val_df.apply(rephrasing_context_from_sample, axis=1).values.tolist()
rephrasing_labels = canard_df.Rewrite.values.tolist()
rephrasing_val_labels = canard_val_df.Rewrite.values.tolist()
print(rephrasing_labels[0])

# Training
# see Adaptor's homepage for details:
# https://github.com/gaussalgo/adaptor
from adaptor.adapter import Adapter
from adaptor.evaluators.generative import BLEU, ROUGE
from adaptor.lang_module import LangModule
from adaptor.objectives.seq2seq import Sequence2Sequence
from adaptor.schedules import ParallelSchedule
from adaptor.utils import AdaptationArguments, StoppingStrategy

lang_module = LangModule("google/t5-large-lm-adapt")

evaluators = [BLEU(), ROUGE()]

seq_qa = Sequence2Sequence(lang_module,
                           texts_or_path=input_texts,
                           labels_or_path=labels,
                           val_texts_or_path=input_val_texts,
                           val_labels_or_path=val_labels,
                           batch_size=4,
                           val_evaluators=evaluators,
                           objective_id="Canard")

hotpot = datasets.load_dataset("hotpot_qa", "distractor")
hotpot_train = hotpot["train"]
hotpot_val = hotpot["validation"]

hotpot_pairs = [
    (text, answer)
    for text, answer in zip(hotpot_train.to_pandas().apply(hotpotqa_context, axis=1),
                            hotpot_train["answer"])
    if len(text) <= MAX_INPUT_CHARS
]
hotpot_inputs = [text for text, _ in hotpot_pairs]
hotpot_answers = [answer for _, answer in hotpot_pairs]
hotpot_val_inputs = hotpot_val.to_pandas().iloc[:N_VAL_SAMPLES].apply(hotpotqa_context, axis=1).tolist()

seq_additional_qa = Sequence2Sequence(lang_module,
                                      texts_or_path=hotpot_inputs,
                                      labels_or_path=hotpot_answers,
                                      val_texts_or_path=hotpot_val_inputs,
                                      val_labels_or_path=hotpot_val["answer"][:N_VAL_SAMPLES],
                                      batch_size=4,
                                      val_evaluators=evaluators,
                                      objective_id="HotpotQA",
                                      share_other_objective_head=seq_qa)

seq_rephrasing = Sequence2Sequence(lang_module,
                                   texts_or_path=rephrasing_inputs,
                                   labels_or_path=rephrasing_labels,
                                   val_texts_or_path=rephrasing_val_inputs,
                                   val_labels_or_path=rephrasing_val_labels,
                                   batch_size=4,
                                   val_evaluators=evaluators,
                                   objective_id="rephrasing",
                                   share_other_objective_head=seq_qa)

training_arguments = AdaptationArguments(output_dir="checkpoints-chatbot",
                                         learning_rate=5e-5,
                                         stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
                                         stopping_patience=8,
                                         save_total_limit=8,
                                         do_train=True,
                                         do_eval=True,
                                         bf16=True,
                                         warmup_steps=1000,
                                         gradient_accumulation_steps=8,
                                         logging_steps=10,
                                         eval_steps=200,
                                         save_steps=1000,
                                         num_train_epochs=10,
                                         evaluation_strategy="steps")

schedule = ParallelSchedule(objectives=[seq_qa, seq_additional_qa, seq_rephrasing],
                            args=training_arguments)

adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()
```

## Usage

To prompt the model, follow the templates used in training: `input_context_from_sample` and `hotpotqa_context` for question answering, and `rephrasing_context_from_sample` for question rephrasing.

## Contact

Feel free to ask questions here, or at stefanik{at} gaussalgo.com