michal-stefanik committed
Commit 32980f3 (1 parent: 966e6f1)

README & training scripts

README.md ADDED
@@ -0,0 +1,60 @@
---
license: mit
datasets:
- sberquad
- adversarial_qa
language:
- en
- ru
metrics:
- rouge
pipeline_tag: text2text-generation
---

# Model Card for mTk-QA_SQuAD-en_SQAD_ru

This model is a generative in-context few-shot learner specialized in Russian. It was trained on a combination of the English AdversarialQA and the Russian SberQuAD datasets.

You can find detailed information on the [project GitHub](https://github.com/fewshot-goes-multilingual/slavic-incontext-learning) and in the referenced paper.

## Model Details

### Model Description

- **Developed by:** [To Be Filled]
- **Model type:** mT5
- **Language(s) (NLP):** en, ru
- **License:** MIT
- **Finetuned from model:** google/mt5-large

### Model Sources

- **Repository:** https://github.com/fewshot-goes-multilingual/slavic-incontext-learning
- **Paper:** [To be filled]

## Uses

This model is intended to be used in a few-shot in-context learning format in the target language (Russian) or in the source language (English; see below).
It was evaluated on unseen-task learning (with k=3 demonstrations) in Russian; see the referenced paper for details.

### How to Get Started with the Model

Use the code below to get started with the model.

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("{this model path}")
tokenizer = AutoTokenizer.from_pretrained("{this model path}")

# The prompt below is in English; for Russian few-shot prompts,
# use the keywords "Вопрос", "Контекст" and "Отвечать" instead.
input_text = """
Question: What is the customer's name?
Context: Origin: Barrack Obama, Customer id: Bill Moe.
Answer: Bill Moe,
Question: What is the customer's name?
Context: Customer id: Barrack Obama, if not deliverable, return to Bill Clinton.
Answer:
"""
inputs = tokenizer(input_text, return_tensors="pt")

outputs = model.generate(**inputs)
print("Answer:")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
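
For Russian prompts, use the same demonstration format with the Russian keywords used during training. A minimal sketch, assuming the `model` and `tokenizer` loaded above (the demonstration content and names are made up for illustration):

```python
# Russian few-shot prompt using the training keywords "Вопрос", "Контекст" and "Отвечать".
# The demonstration content below is illustrative only.
input_text_ru = """
Вопрос: Как зовут клиента? Контекст: Отправитель: Анна Петрова, номер клиента: Иван Сидоров. Отвечать: Иван Сидоров,
Вопрос: Как зовут клиента? Контекст: Номер клиента: Анна Петрова, при недоставке вернуть Петру Иванову. Отвечать:
"""
inputs_ru = tokenizer(input_text_ru, return_tensors="pt")
outputs_ru = model.generate(**inputs_ru)
print(tokenizer.decode(outputs_ru[0], skip_special_tokens=True))
```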

## Training Details

Training of this model can be reproduced by running `pip install -r requirements.txt && python train_mt5_qa_en_AQA+ru_info.py`.
See the referenced script for the hyperparameters and other training configuration.

## Citation

<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->

**BibTeX:**

[Will be filled soon]
priming_objective.py ADDED
@@ -0,0 +1,200 @@
import logging
import random
from typing import Iterable, Union, Dict, List, Optional

import torch
from adaptor.objectives.seq2seq import Sequence2Sequence
from transformers import BatchEncoding

logger = logging.getLogger()

priming_formats = {
    "QA": {"cs": "Otázka: %s Kontext: %s Odpověď:",
           "en": "Question: %s Context: %s Answer:",
           "ru": "Вопрос: %s Контекст: %s Отвечать:"}
}


class Priming(Sequence2Sequence):
    """Sequence2Sequence objective that prepends in-context demonstrations to every QA prompt,
    selecting them per question category or randomly, depending on the chosen strategy."""

    def __init__(self, *args,
                 train_question_categories: Iterable[str],
                 max_eval_samples: int,
                 val_question_categories: Optional[Iterable[str]] = None,
                 min_num_demonstrations: int = 2,
                 max_num_demonstrations: int = 5,
                 demos_infer_batch_size: int = 32,
                 demos_selection_strategy: str = "hard",
                 difficulty_sample: int = 64,
                 max_input_length: int = 8000,
                 **kwargs):
        super().__init__(*args, **kwargs)

        self.train_question_categories = list(train_question_categories)
        self.val_question_categories = list(val_question_categories) if val_question_categories is not None else None

        self.min_num_demonstrations = min_num_demonstrations
        self.max_num_demonstrations = max_num_demonstrations
        self.demos_infer_batch_size = demos_infer_batch_size
        self.demos_selection_strategy = demos_selection_strategy
        self.difficulty_sample = difficulty_sample
        self.max_input_length = max_input_length
        self.max_eval_samples = max_eval_samples

    def _construct_qa_prompt(self, question: str, context: str) -> str:
        return priming_formats["QA"][self.source_lang_id] % (question, context)

    def _construct_demonstration(self, prompt: str, answer: str) -> str:
        return "%s %s " % (prompt, answer)

    def _construct_primed_prompt(self, primed_demonstrations: List[str], prompt: str) -> str:
        return " ".join(primed_demonstrations) + " " + prompt

    def forced_generation_score(self, input_texts: List[str], forced_output: str) -> torch.FloatTensor:
        inputs = self.tokenizer(input_texts, return_tensors="pt", padding="longest", truncation=True)
        inputs = inputs.to(self.compatible_head_model.device)

        with self.tokenizer.as_target_tokenizer():
            output_ids = self.tokenizer(forced_output, return_tensors="pt", padding="longest",
                                        truncation=True).input_ids.to(self.compatible_head_model.device)
        forced_outputs = self.compatible_head_model.prepare_decoder_input_ids_from_labels(output_ids)
        forced_outputs = forced_outputs.to(self.compatible_head_model.device)

        outputs = self.compatible_head_model(**inputs,
                                             decoder_input_ids=forced_outputs.expand(inputs.input_ids.shape[0], -1))
        output_log_probs = outputs.logits.log_softmax(-1)
        forced_output_logits = torch.gather(output_log_probs, -1,
                                            output_ids.expand(inputs.input_ids.shape[0], -1).unsqueeze(-1))
        forced_output_log_score = forced_output_logits.sum((-1, -2))
        # no length normalization is needed: all candidates are scored against the same (equal-length) target
        return forced_output_log_score.double().exp()

    def _pick_most_difficult_demo(self,
                                  selected_demos: List[str],
                                  next_demo_cands: List[str],
                                  predict_prompt: str,
                                  predicted_answer: str) -> int:
        with torch.no_grad():
            difficulties = torch.empty(0, device=self.compatible_head_model.device, dtype=torch.float)

            for batch_offset in range(0, len(next_demo_cands), self.demos_infer_batch_size):
                next_demo_cands_batch = next_demo_cands[batch_offset: batch_offset + self.demos_infer_batch_size]

                primed_prompts = [self._construct_primed_prompt(selected_demos + [demo], predict_prompt)
                                  for demo in next_demo_cands_batch]
                cands_difficulty = self.forced_generation_score(primed_prompts, predicted_answer)

                difficulties = torch.hstack((difficulties, cands_difficulty))

        assert difficulties.argmin() < len(next_demo_cands)

        return difficulties.argmin()

    def _get_inputs_iterator(self, split: str) -> Iterable[Union[BatchEncoding, Dict[str, torch.Tensor]]]:
        """
        Creates an iterator over encodings of prompts primed with demonstrations, aligned with the output texts.
        :param split: Data split. `train` or `eval`.
        :return: Iterator of model input encodings.
        """
        # we materialize all samples in memory, so that we can heuristically pick the combinations
        questions, contexts, answers = (list(it) for it in self._per_split_iterators(split))
        question_categories = self.train_question_categories if split == "train" else self.val_question_categories

        assert len(questions) == len(contexts) == len(answers) == len(question_categories), \
            "Given numbers of questions, contexts and answers do not match."

        prompts = [self._construct_qa_prompt(q, c) for q, c in zip(questions, contexts)]

        features_batch = []
        cat_index = {cat: [i for i, sample_cat in enumerate(question_categories) if cat == sample_cat]
                     for cat in set(question_categories)}

        retrieved_samples = 0

        for idx, sample_category in enumerate(question_categories):
            if not cat_index[sample_category]:
                logger.warning("No samples within the category %s", sample_category)
                continue

            pred_prompt, pred_answer = prompts[idx], answers[idx]

            picked_demonstrations = []

            # the number of demonstrations is drawn from the specified range
            expected_num_demonstrations = random.randint(self.min_num_demonstrations, self.max_num_demonstrations)

            while len(picked_demonstrations) < expected_num_demonstrations:
                if sum(map(len, picked_demonstrations)) > self.max_input_length:
                    logger.warning("Skipping too long prompt.")
                    break
                if self.demos_selection_strategy == "hard":
                    # pick the most difficult examples out of a random sample;
                    # with the hard strategy, we do not need to worry about exposing the predicted sample among demonstrations
                    if len(cat_index[sample_category]) <= 1:
                        # we can not construct informative demonstrations for categories with a single item
                        break

                    samples_idx = random.choices(cat_index[sample_category], k=self.difficulty_sample)
                    cand_demonstrations = [self._construct_demonstration(prompts[i], answers[i]) for i in samples_idx]
                    selected_index = self._pick_most_difficult_demo(picked_demonstrations, cand_demonstrations,
                                                                    pred_prompt, pred_answer)
                    picked_demonstrations.append(cand_demonstrations[selected_index])
                elif self.demos_selection_strategy == "informative":
                    if len(cat_index[sample_category]) <= 1:
                        # we can not construct informative demonstrations for categories with a single item
                        break
                    selected_cat_index = random.randint(1, len(cat_index[sample_category]) - 1)
                    selected_index = cat_index[sample_category][selected_cat_index]
                    if selected_index == idx:
                        # we do not want to expose the predicted sample in demonstrations
                        selected_index = cat_index[sample_category][selected_cat_index - 1]
                    picked_demonstration = self._construct_demonstration(prompts[selected_index],
                                                                         answers[selected_index])
                    picked_demonstrations.append(picked_demonstration)
                elif self.demos_selection_strategy == "random":
                    # evaluation: do not infer samples' difficulty, pick randomly
                    selected_index = random.randint(1, len(prompts) - 1)
                    if selected_index == idx:
                        # we do not want to expose the predicted sample in demonstrations
                        selected_index -= 1
                    picked_demonstration = self._construct_demonstration(prompts[selected_index],
                                                                         answers[selected_index])
                    picked_demonstrations.append(picked_demonstration)
                else:
                    raise ValueError("Unknown demonstration selection strategy: '%s'" % self.demos_selection_strategy)
            if len(picked_demonstrations) != expected_num_demonstrations:
                # we omit examples with no, or only one, demonstration in the category
                continue

            # encode a yielded batch
            primed_prompt = self._construct_primed_prompt(picked_demonstrations, pred_prompt)

            primed_prompt_encoding = self.tokenizer(primed_prompt, truncation=True)
            label_encoding = self.tokenizer(pred_answer, truncation=True)

            features_batch.append({"input_ids": primed_prompt_encoding.input_ids,
                                   "attention_mask": primed_prompt_encoding.attention_mask,
                                   "labels": label_encoding.input_ids})
            if len(features_batch) == self.batch_size:
                yield self.collator(features_batch)
                features_batch = []

            retrieved_samples += 1
            if split == "eval" and retrieved_samples >= self.max_eval_samples:
                # custom evaluation break - we need all samples in the set to match the categories,
                # but do not want to iterate over all of them
                break

        if features_batch:
            # yield the last nonempty residual batch
            yield self.collator(features_batch)

    def _compute_loss(self,
                      lm_logit_outputs: torch.FloatTensor,
                      labels: torch.LongTensor,
                      inputs: Optional[Union[BatchEncoding, Dict[str, torch.Tensor]]] = None) -> torch.FloatTensor:
        # customization for the mT5 model, whose tokenizer.vocab_size is set incorrectly;
        # this should be fixed in an upcoming release of adaptor (>=0.1.5)
        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(lm_logit_outputs.flatten(end_dim=1), labels.flatten())

        return lm_loss
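
For reference, the primed-prompt format this objective assembles (demonstrations followed by the predicted prompt) can be reproduced with the same string templates. A standalone sketch with made-up QA pairs, not part of the training code:

```python
# Standalone illustration of the primed-prompt format assembled by Priming
# (same templates as priming_formats["QA"]["en"]); the QA pairs are made up.
qa_format = "Question: %s Context: %s Answer:"

demo_1 = "%s %s " % (qa_format % ("Who wrote Hamlet?", "Hamlet is a tragedy by William Shakespeare."),
                     "William Shakespeare")
demo_2 = "%s %s " % (qa_format % ("Who wrote Faust?", "Faust is a play by Johann Wolfgang von Goethe."),
                     "Goethe")
predicted_prompt = qa_format % ("Who wrote War and Peace?", "War and Peace is a novel by Leo Tolstoy.")

# demonstrations are space-joined and followed by the prompt to be answered
primed_prompt = " ".join([demo_1, demo_2]) + " " + predicted_prompt
print(primed_prompt)
```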
requirements.txt ADDED
@@ -0,0 +1,4 @@
adaptor[generative]==0.2.0
torch==1.11.0
pandas
nltk
train_mt5_qa_en_AQA+ru_info.py ADDED
@@ -0,0 +1,103 @@
from typing import List

from adaptor.adapter import Adapter
from adaptor.evaluators.generative import BLEU
from adaptor.lang_module import LangModule
from adaptor.schedules import ParallelSchedule
from adaptor.utils import AdaptationArguments, StoppingStrategy
from datasets import load_dataset

from priming_objective import Priming

training_arguments = AdaptationArguments(output_dir="train_dir_AQA_info_large_ru",
                                         learning_rate=2e-5,  # we set LR=2e-4 for pre-training experiments
                                         stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
                                         # stopping_strategy=StoppingStrategy.NUM_STEPS_TOTAL,
                                         do_train=True,
                                         do_eval=True,
                                         warmup_steps=1000,
                                         max_steps=10000,
                                         gradient_accumulation_steps=30,
                                         eval_steps=500,
                                         logging_steps=10,
                                         save_steps=500,
                                         num_train_epochs=5,
                                         evaluation_strategy="steps",
                                         save_total_limit=10,
                                         stopping_patience=10)
eval_examples = 200

# priming
num_demonstrations = 3


def _construct_priming_prompt(previous_examples: List[str], current_example: str) -> str:
    return " ".join(previous_examples + [current_example])


lang_module = LangModule("google/mt5-large")

# priming
per_type_examples = {}

qa_en = load_dataset("adversarial_qa", "adversarialQA")
qa_train = qa_en["train"].filter(lambda entry: len(entry["context"]) < 2000)

val_metrics = [BLEU(**{"additional_sep_char": "▁"})]

# Adversarial QA dataset & objective:


def _get_firstword_categories(data) -> List[str]:
    return [question.split()[0] if not question.startswith("To")
            else " ".join(question.split()[:2])
            for question in data["question"]]


q_answering_en = Priming(lang_module,
                         max_eval_samples=eval_examples,
                         demos_selection_strategy="informative",
                         texts_or_path=qa_train["question"],
                         text_pair_or_path=qa_train["context"],
                         val_texts_or_path=qa_en["validation"]["question"][-eval_examples:],
                         val_text_pair_or_path=qa_en["validation"]["context"][-eval_examples:],
                         labels_or_path=[a["text"][0] for a in qa_train["answers"]],
                         val_labels_or_path=[a["text"][0] for a in qa_en["validation"]["answers"]][-eval_examples:],
                         train_question_categories=_get_firstword_categories(qa_train),
                         val_question_categories=_get_firstword_categories(qa_en["validation"])[-eval_examples:],
                         batch_size=1,
                         val_evaluators=val_metrics,
                         source_lang_id="en",
                         objective_id="AQA-en")

qa_ru = load_dataset("sberquad")
qa_ru_train = qa_ru["train"].filter(lambda entry: len(entry["context"]) < 800)

skipped = 0

q_answering_ru = Priming(lang_module,
                         max_eval_samples=eval_examples,
                         demos_selection_strategy="informative",
                         texts_or_path=qa_ru_train["question"],
                         text_pair_or_path=qa_ru_train["context"],
                         val_texts_or_path=qa_ru["validation"]["question"][-eval_examples:],
                         val_text_pair_or_path=qa_ru["validation"]["context"][-eval_examples:],
                         labels_or_path=[a["text"][0] for a in qa_ru_train["answers"]],
                         val_labels_or_path=[a["text"][0] for a in qa_ru["validation"]["answers"]][-eval_examples:],
                         train_question_categories=_get_firstword_categories(qa_ru_train),
                         val_question_categories=_get_firstword_categories(qa_ru["validation"])[-eval_examples:],
                         batch_size=1,
                         val_evaluators=val_metrics,
                         source_lang_id="ru",
                         objective_id="SQuAD-ru")

schedule = ParallelSchedule(objectives=[q_answering_en, q_answering_ru],
                            args=training_arguments)

adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()
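
As a quick illustration of the category grouping used by both objectives above, `_get_firstword_categories` keys each question by its first word, or by its first two words for questions starting with "To". A small sketch with made-up questions:

```python
# Illustration of the first-word question categories (the questions are made up).
sample = {"question": ["What is the capital of France?",
                       "To whom was the letter addressed?",
                       "When did the war end?"]}
print(_get_firstword_categories(sample))
# ['What', 'To whom', 'When']
```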