rookiemango committed on
Commit 33ab1c8 · verified · 1 Parent(s): 1be87dd

Upload folder using huggingface_hub

all_code.py ADDED
@@ -0,0 +1,66 @@
+ import tqdm
+ import json
+ def filtered():
+     import json
+     with open("data/lean4_basic/1k_test.jsonl", "r") as f:
+         test_data = json.load(f)
+     test_data = [item['statement_poof'] for item in test_data]
+
+     # Function to filter items based on existence in test data
+
+
+     with open("data/lean4_random/5k_second.json", "r") as f:
+         second_5k = json.load(f)
+
+     def filter_items(data, test_data):
+         filtered_data = [item for item in tqdm.tqdm(data) if item['statement_poof'][:-2] not in test_data]
+         return filtered_data
+
+     # Filter and save filtered data
+
+     filtered_second_5k = filter_items(second_5k, test_data)
+     with open("data/lean4_random/5k_second_filtered.json", "w") as f:
+         json.dump(filtered_second_5k, f, ensure_ascii=False, indent=2)
+     print("Filtered second file length:", len(filtered_second_5k))
+
+
+
+
+ def insert_label_for_autoformalization():
+     input_lists = ["data/lean4_statement_translate/15k_state_problem_translation.json"]
+     for input_file in input_lists:
+         with open(input_file, "r") as f:
+             test_data = json.load(f)
+         for item in test_data:
+             item['task']='statement_form'
+
+     with open("data/lean4_statement_translate/15k_state_problem_translation_statement_form.json", "w") as f:
+         json.dump(test_data, f, ensure_ascii=False, indent=2)
+
+
+     input_lists = ["data/lean4_statement_translate/15k_state_problem_translation.json"]
+     for input_file in input_lists:
+         with open(input_file, "r") as f:
+             test_data = json.load(f)
+         for item in test_data:
+             item['task']='statementproof_inform'
+     with open("data/lean4_statement_translate/15k_state_problem_translation_statementproof_inform.json", "w") as f:
+         json.dump(test_data, f, ensure_ascii=False, indent=2)
+
+
+     input_lists = ["all_theorem.jsonl"]
+     for input_file in input_lists:
+         with open(input_file, "r") as f:
+             test_data = json.load(f)
+         for item in test_data:
+             item['task']='solver'
+
+     with open("data/all_theorem_solver.jsonl", "w") as f:
+         json.dump(test_data, f, ensure_ascii=False, indent=2)
+
+
+
+
+ if __name__ == '__main__':
+     insert_label_for_autoformalization()
+     # filtered()
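A minimal, self-contained sketch of the dedup step that filter_items performs above. The statement_poof key and the trailing two-character trim follow the script; the record values are invented for illustration, and the real script loads them from the JSON/JSONL files named above.

import tqdm

# Hypothetical in-memory records standing in for 1k_test.jsonl and 5k_second.json.
test_data = ["theorem foo : 1 + 1 = 2 := by simp"]
second_5k = [
    {"statement_poof": "theorem foo : 1 + 1 = 2 := by simp\n\n"},  # overlaps the test split -> dropped
    {"statement_poof": "theorem bar : 2 + 2 = 4 := by simp\n\n"},  # no overlap -> kept
]

def filter_items(data, test_data):
    # Keep only items whose statement (minus its last two characters) does not
    # already appear verbatim among the held-out test statements.
    return [item for item in tqdm.tqdm(data) if item["statement_poof"][:-2] not in test_data]

print(len(filter_items(second_5k, test_data)))  # 1

Turning test_data into a set before the scan would make each membership check O(1), which matters once the training pool grows beyond a few thousand items.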
auto_backform.py ADDED
File without changes
autobackform_train.py ADDED
@@ -0,0 +1,305 @@
1
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Modified by Zheng Yuan and Hongyi Yuan
15
+
16
+ import os
17
+ import copy
18
+ import logging
19
+ from dataclasses import dataclass, field
20
+ from typing import Optional, Dict, Sequence
21
+ import io
22
+ import torch
23
+ import transformers
24
+ from torch.utils.data import Dataset
25
+ from transformers import Trainer
26
+ import argparse
27
+ import json
28
+ import random;random.seed(42)
29
+
30
+ def _make_r_io_base(f, mode: str):
31
+ if not isinstance(f, io.IOBase):
32
+ f = open(f, mode=mode)
33
+ return f
34
+
35
+ def jload(f, mode="r"):
36
+ """Load a .json file into a dictionary."""
37
+ f = _make_r_io_base(f, mode)
38
+ jdict = json.load(f)
39
+ f.close()
40
+ return jdict
41
+
42
+ IGNORE_INDEX = -100
43
+ DEFAULT_PAD_TOKEN = "[PAD]"
44
+ DEFAULT_EOS_TOKEN = "</s>"
45
+ DEFAULT_BOS_TOKEN = "<s>"
46
+ DEFAULT_UNK_TOKEN = "<unk>"
47
+ PROMPT_DICT = {
48
+ "lean4": (
49
+ "Statement and proof in natural language:\n\n"
50
+ "{statement_text}\n\n"
51
+ "Translate the statement and proof in natural language to lean4:"
52
+ ),
53
+ "backform": (
54
+ "Statement and proof in lean4:\n\n"
55
+ "{statement_text}\n\n"
56
+ "Translate the statement and proof in lean4 to natural language:"
57
+ ),
58
+ "prompt_no_input": (
59
+ "Below is an instruction that describes a task. "
60
+ "Write a response that appropriately completes the request.\n\n"
61
+ "### Instruction:\n{instruction}\n\n### Response:"
62
+ ),
63
+ }
64
+ #### 28
65
+ @dataclass
66
+ class ModelArguments:
67
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
68
+
69
+
70
+ @dataclass
71
+ class DataArguments:
72
+ data_path: str = field(default=None, metadata={"help": "Path to the training data."})
73
+
74
+
75
+ @dataclass
76
+ class TrainingArguments(transformers.TrainingArguments):
77
+ cache_dir: Optional[str] = field(default=None)
78
+ optim: str = field(default="adamw_torch")
79
+ model_max_length: int = field(
80
+ default=2048,
81
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
82
+ )
83
+ overwrite_output_dir: bool = field(default=True)
84
+
85
+
86
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
87
+ """Collects the state dict and dump to disk."""
88
+ state_dict = trainer.model.state_dict()
89
+ if trainer.args.should_save:
90
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
91
+ del state_dict
92
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
93
+
94
+
95
+ def smart_tokenizer_and_embedding_resize(
96
+ special_tokens_dict: Dict,
97
+ tokenizer: transformers.PreTrainedTokenizer,
98
+ model: transformers.PreTrainedModel,
99
+ ):
100
+ """Resize tokenizer and embedding.
101
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
102
+ """
103
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
104
+ model.resize_token_embeddings(len(tokenizer))
105
+
106
+ if num_new_tokens > 0:
107
+ input_embeddings = model.get_input_embeddings().weight.data
108
+ output_embeddings = model.get_output_embeddings().weight.data
109
+
110
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
111
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
112
+
113
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
114
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
115
+
116
+
117
+ def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
118
+ """Tokenize a list of strings."""
119
+ tokenized_list = [
120
+ tokenizer(
121
+ text,
122
+ return_tensors="pt",
123
+ padding="longest",
124
+ max_length=tokenizer.model_max_length,
125
+ truncation=True,
126
+ )
127
+ for text in strings
128
+ ]
129
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
130
+ input_ids_lens = labels_lens = [
131
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
132
+ ]
133
+ return dict(
134
+ input_ids=input_ids,
135
+ labels=labels,
136
+ input_ids_lens=input_ids_lens,
137
+ labels_lens=labels_lens,
138
+ )
139
+
140
+
141
+ def preprocess(
142
+ sources: Sequence[str],
143
+ targets: Sequence[str],
144
+ tokenizer: transformers.PreTrainedTokenizer,
145
+ ) -> Dict:
146
+ """Preprocess the data by tokenizing."""
147
+ examples = [s + t for s, t in zip(sources, targets)]
148
+ examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
149
+ input_ids = examples_tokenized["input_ids"]
150
+ labels = copy.deepcopy(input_ids)
151
+ for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
152
+ label[:source_len] = IGNORE_INDEX
153
+ return dict(input_ids=input_ids, labels=labels)
154
+
155
+ class SupervisedDataset(Dataset):
156
+ """Dataset for supervised fine-tuning."""
157
+ def __init__(self, data_args, tokenizer: transformers.PreTrainedTokenizer):
158
+ super(SupervisedDataset, self).__init__()
159
+ logging.warning("Loading data...")
160
+ data_path = data_args.data_path
161
+ try:
162
+ data_path = data_path_map[data_path]
163
+ except:
164
+ data_path = data_path
165
+ list_data_dict = []
166
+ for item in data_path.split(','):
167
+ try:
168
+ list_data_dict += jload(item)
169
+
170
+ except BaseException:
171
+ with open(item, 'r') as f:
172
+ lines = f.readlines()
173
+ list_data_dict += [json.loads(line.strip()) for line in lines]
174
+
175
+ list_data_dict = random.sample(list_data_dict, len(list_data_dict))
176
+ list_data_dict = list_data_dict[:data_args.data_length]
177
+
178
+ logging.warning("Formatting inputs...")
179
+ prompt_lean4 = PROMPT_DICT["backform"]
180
+
181
+ # list_data_dict = [{'instruction':data['items'][0]['value'], 'input':'', 'output':data['items'][1]['value']} for data in list_data_dict]
182
+
183
+ list_data_dict = [{'instruction':prompt_lean4.format(statement_text = data['statement_poof']), 'input':'', 'output':data['model_response']} for data in list_data_dict]
184
+ print(f"len of {len(list_data_dict)}")
185
+ sources = [example['instruction'] for example in list_data_dict]
186
+ targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
187
+ # targets = [example['output'] for example in list_data_dict]
188
+
189
+ self.sources = sources
190
+ self.targets = targets
191
+
192
+ def __len__(self):
193
+ return len(self.sources)
194
+
195
+ def naive__getitem__(self, i) -> Dict[str, torch.Tensor]:
196
+ return dict(input_ids=self.input_ids[i], labels=self.labels[i])
197
+
198
+ def __getitem__(self, i):
199
+ return dict(input_ids=self.sources[i], labels=self.targets[i])
200
+
201
+ @dataclass
202
+ class DataCollatorForSupervisedDataset(object):
203
+ """Collate examples for supervised fine-tuning."""
204
+
205
+ tokenizer: transformers.PreTrainedTokenizer
206
+
207
+ def naive__call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
208
+ input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
209
+ input_ids = torch.nn.utils.rnn.pad_sequence(
210
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
211
+ )
212
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
213
+ return dict(
214
+ input_ids=input_ids,
215
+ labels=labels,
216
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
217
+ )
218
+
219
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
220
+ sources = []
221
+ targets = []
222
+ for instance in instances:
223
+ source = instance['input_ids']
224
+ target = instance['labels']
225
+ sources.append(source)
226
+ targets.append(target)
227
+
228
+ data_dict = preprocess(sources, targets, self.tokenizer)
229
+ input_ids, labels = data_dict['input_ids'], data_dict['labels']
230
+ # input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
231
+ input_ids = torch.nn.utils.rnn.pad_sequence(
232
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
233
+ )
234
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
235
+ return dict(
236
+ input_ids=input_ids,
237
+ labels=labels,
238
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
239
+ )
240
+
241
+ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
242
+ """Make dataset and collator for supervised fine-tuning."""
243
+ train_dataset = SupervisedDataset(tokenizer=tokenizer, data_args=data_args)
244
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
245
+ return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
246
+
247
+
248
+ os.environ["WANDB_PROJECT"] = "auto_backform"
249
+
250
+ def train():
251
+
252
+ parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
253
+ model_args, data_args, training_args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
254
+ data_args.data_length = int(remaining_args[1])
255
+
256
+ model = transformers.AutoModelForCausalLM.from_pretrained(
257
+ model_args.model_name_or_path,
258
+ cache_dir=training_args.cache_dir,
259
+ trust_remote_code=True,
260
+ torch_dtype=torch.bfloat16,
261
+ attn_implementation="flash_attention_2",
262
+ )
263
+
264
+ model.config.use_cache = False
265
+ model.gradient_checkpointing_enable()
266
+
267
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
268
+ model_args.model_name_or_path,
269
+ cache_dir=training_args.cache_dir,
270
+ model_max_length=training_args.model_max_length,
271
+ padding_side="right",
272
+ use_fast=False,
273
+ )
274
+ if tokenizer.pad_token is None:
275
+ smart_tokenizer_and_embedding_resize(
276
+ special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
277
+ tokenizer=tokenizer,
278
+ model=model,
279
+ )
280
+ if "llama" in model_args.model_name_or_path:
281
+ tokenizer.add_special_tokens(
282
+ {
283
+ "eos_token": DEFAULT_EOS_TOKEN,
284
+ "bos_token": DEFAULT_BOS_TOKEN,
285
+ "unk_token": DEFAULT_UNK_TOKEN,
286
+ }
287
+ )
288
+ try:
289
+ tokenizer.pad_token = tokenizer.unk_token
290
+ except:
291
+ pass
292
+
293
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
294
+ trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
295
+ trainer.train()
296
+ model.config.use_cache = True
297
+ # trainer.save_state()
298
+ # if os.environ.get('LOCAL_RANK') == '0':
299
+ safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
300
+
301
+
302
+
303
+
304
+ if __name__ == "__main__":
305
+ train()
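The core of the script above is the loss masking done in preprocess: prompt tokens are overwritten with IGNORE_INDEX so that cross-entropy is computed only on the response tokens. A small sketch with made-up token ids (-100 is the value PyTorch's cross-entropy ignores by default):

import torch

IGNORE_INDEX = -100
# Suppose the prompt tokenizes to 4 ids and the target translation to 3 ids.
input_ids = torch.tensor([17, 5, 942, 8, 301, 77, 12])  # prompt + response, concatenated
labels = input_ids.clone()
source_len = 4                       # length of the tokenized prompt alone
labels[:source_len] = IGNORE_INDEX   # mask every prompt position
print(labels)  # tensor([-100, -100, -100, -100,  301,   77,   12])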
autoform_train.py ADDED
@@ -0,0 +1,300 @@
1
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Modified by Zheng Yuan and Hongyi Yuan
15
+
16
+ import os
17
+ import copy
18
+ import logging
19
+ from dataclasses import dataclass, field
20
+ from typing import Optional, Dict, Sequence
21
+ import io
22
+ import torch
23
+ import transformers
24
+ from torch.utils.data import Dataset
25
+ from transformers import Trainer
26
+ import argparse
27
+ import json
28
+ import random;random.seed(42)
29
+
30
+ def _make_r_io_base(f, mode: str):
31
+ if not isinstance(f, io.IOBase):
32
+ f = open(f, mode=mode)
33
+ return f
34
+
35
+ def jload(f, mode="r"):
36
+ """Load a .json file into a dictionary."""
37
+ f = _make_r_io_base(f, mode)
38
+ jdict = json.load(f)
39
+ f.close()
40
+ return jdict
41
+
42
+ IGNORE_INDEX = -100
43
+ DEFAULT_PAD_TOKEN = "[PAD]"
44
+ DEFAULT_EOS_TOKEN = "</s>"
45
+ DEFAULT_BOS_TOKEN = "<s>"
46
+ DEFAULT_UNK_TOKEN = "<unk>"
47
+ PROMPT_DICT = {
48
+ "lean4": (
49
+ "Statement and proof in natural language:\n\n"
50
+ "{statement_text}\n\n"
51
+ "Translate the statement and proof in natural language to lean4:"
52
+ ),
53
+ "prompt_no_input": (
54
+ "Below is an instruction that describes a task. "
55
+ "Write a response that appropriately completes the request.\n\n"
56
+ "### Instruction:\n{instruction}\n\n### Response:"
57
+ ),
58
+ }
59
+ #### 28
60
+ @dataclass
61
+ class ModelArguments:
62
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
63
+
64
+
65
+ @dataclass
66
+ class DataArguments:
67
+ data_path: str = field(default=None, metadata={"help": "Path to the training data."})
68
+
69
+
70
+ @dataclass
71
+ class TrainingArguments(transformers.TrainingArguments):
72
+ cache_dir: Optional[str] = field(default=None)
73
+ optim: str = field(default="adamw_torch")
74
+ model_max_length: int = field(
75
+ default=2048,
76
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
77
+ )
78
+ overwrite_output_dir: bool = field(default=True)
79
+
80
+
81
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
82
+ """Collects the state dict and dump to disk."""
83
+ state_dict = trainer.model.state_dict()
84
+ if trainer.args.should_save:
85
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
86
+ del state_dict
87
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
88
+
89
+
90
+ def smart_tokenizer_and_embedding_resize(
91
+ special_tokens_dict: Dict,
92
+ tokenizer: transformers.PreTrainedTokenizer,
93
+ model: transformers.PreTrainedModel,
94
+ ):
95
+ """Resize tokenizer and embedding.
96
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
97
+ """
98
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
99
+ model.resize_token_embeddings(len(tokenizer))
100
+
101
+ if num_new_tokens > 0:
102
+ input_embeddings = model.get_input_embeddings().weight.data
103
+ output_embeddings = model.get_output_embeddings().weight.data
104
+
105
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
106
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
107
+
108
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
109
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
110
+
111
+
112
+ def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
113
+ """Tokenize a list of strings."""
114
+ tokenized_list = [
115
+ tokenizer(
116
+ text,
117
+ return_tensors="pt",
118
+ padding="longest",
119
+ max_length=tokenizer.model_max_length,
120
+ truncation=True,
121
+ )
122
+ for text in strings
123
+ ]
124
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
125
+ input_ids_lens = labels_lens = [
126
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
127
+ ]
128
+ return dict(
129
+ input_ids=input_ids,
130
+ labels=labels,
131
+ input_ids_lens=input_ids_lens,
132
+ labels_lens=labels_lens,
133
+ )
134
+
135
+
136
+ def preprocess(
137
+ sources: Sequence[str],
138
+ targets: Sequence[str],
139
+ tokenizer: transformers.PreTrainedTokenizer,
140
+ ) -> Dict:
141
+ """Preprocess the data by tokenizing."""
142
+ examples = [s + t for s, t in zip(sources, targets)]
143
+ examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
144
+ input_ids = examples_tokenized["input_ids"]
145
+ labels = copy.deepcopy(input_ids)
146
+ for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
147
+ label[:source_len] = IGNORE_INDEX
148
+ return dict(input_ids=input_ids, labels=labels)
149
+
150
+ class SupervisedDataset(Dataset):
151
+ """Dataset for supervised fine-tuning."""
152
+ def __init__(self, data_args, tokenizer: transformers.PreTrainedTokenizer):
153
+ super(SupervisedDataset, self).__init__()
154
+ logging.warning("Loading data...")
155
+ data_path = data_args.data_path
156
+ try:
157
+ data_path = data_path_map[data_path]
158
+ except:
159
+ data_path = data_path
160
+ list_data_dict = []
161
+ for item in data_path.split(','):
162
+ try:
163
+ list_data_dict += jload(item)
164
+
165
+ except BaseException:
166
+ with open(item, 'r') as f:
167
+ lines = f.readlines()
168
+ list_data_dict += [json.loads(line.strip()) for line in lines]
169
+
170
+ list_data_dict = random.sample(list_data_dict, len(list_data_dict))
171
+ list_data_dict = list_data_dict[:data_args.data_length]
172
+
173
+ logging.warning("Formatting inputs...")
174
+ prompt_lean4 = PROMPT_DICT["lean4"]
175
+
176
+ # list_data_dict = [{'instruction':data['items'][0]['value'], 'input':'', 'output':data['items'][1]['value']} for data in list_data_dict]
177
+
178
+ list_data_dict = [{'instruction':prompt_lean4.format(statement_text = data['model_response']), 'input':'', 'output':data['statement_poof']} for data in list_data_dict]
179
+ print(f"len of {len(list_data_dict)}")
180
+ sources = [example['instruction'] for example in list_data_dict]
181
+ targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
182
+ # targets = [example['output'] for example in list_data_dict]
183
+
184
+ self.sources = sources
185
+ self.targets = targets
186
+
187
+ def __len__(self):
188
+ return len(self.sources)
189
+
190
+ def naive__getitem__(self, i) -> Dict[str, torch.Tensor]:
191
+ return dict(input_ids=self.input_ids[i], labels=self.labels[i])
192
+
193
+ def __getitem__(self, i):
194
+ return dict(input_ids=self.sources[i], labels=self.targets[i])
195
+
196
+ @dataclass
197
+ class DataCollatorForSupervisedDataset(object):
198
+ """Collate examples for supervised fine-tuning."""
199
+
200
+ tokenizer: transformers.PreTrainedTokenizer
201
+
202
+ def naive__call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
203
+ input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
204
+ input_ids = torch.nn.utils.rnn.pad_sequence(
205
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
206
+ )
207
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
208
+ return dict(
209
+ input_ids=input_ids,
210
+ labels=labels,
211
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
212
+ )
213
+
214
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
215
+ sources = []
216
+ targets = []
217
+ for instance in instances:
218
+ source = instance['input_ids']
219
+ target = instance['labels']
220
+ sources.append(source)
221
+ targets.append(target)
222
+
223
+ data_dict = preprocess(sources, targets, self.tokenizer)
224
+ input_ids, labels = data_dict['input_ids'], data_dict['labels']
225
+ # input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
226
+ input_ids = torch.nn.utils.rnn.pad_sequence(
227
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
228
+ )
229
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
230
+ return dict(
231
+ input_ids=input_ids,
232
+ labels=labels,
233
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
234
+ )
235
+
236
+ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
237
+ """Make dataset and collator for supervised fine-tuning."""
238
+ train_dataset = SupervisedDataset(tokenizer=tokenizer, data_args=data_args)
239
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
240
+ return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
241
+
242
+
243
+ os.environ["WANDB_PROJECT"] = "auto_form"
244
+
245
+ def train():
246
+
247
+ parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
248
+ model_args, data_args, training_args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
249
+ data_args.data_length = int(remaining_args[1])
250
+
251
+ model = transformers.AutoModelForCausalLM.from_pretrained(
252
+ model_args.model_name_or_path,
253
+ cache_dir=training_args.cache_dir,
254
+ trust_remote_code=True,
255
+ torch_dtype=torch.bfloat16,
256
+ attn_implementation="flash_attention_2",
257
+ )
258
+
259
+ model.config.use_cache = False
260
+ model.gradient_checkpointing_enable()
261
+
262
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
263
+ model_args.model_name_or_path,
264
+ cache_dir=training_args.cache_dir,
265
+ model_max_length=training_args.model_max_length,
266
+ padding_side="right",
267
+ use_fast=False,
268
+ )
269
+ if tokenizer.pad_token is None:
270
+ smart_tokenizer_and_embedding_resize(
271
+ special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
272
+ tokenizer=tokenizer,
273
+ model=model,
274
+ )
275
+ if "llama" in model_args.model_name_or_path:
276
+ tokenizer.add_special_tokens(
277
+ {
278
+ "eos_token": DEFAULT_EOS_TOKEN,
279
+ "bos_token": DEFAULT_BOS_TOKEN,
280
+ "unk_token": DEFAULT_UNK_TOKEN,
281
+ }
282
+ )
283
+ try:
284
+ tokenizer.pad_token = tokenizer.unk_token
285
+ except:
286
+ pass
287
+
288
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
289
+ trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
290
+ trainer.train()
291
+ model.config.use_cache = True
292
+ # trainer.save_state()
293
+ # if os.environ.get('LOCAL_RANK') == '0':
294
+ safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
295
+
296
+
297
+
298
+
299
+ if __name__ == "__main__":
300
+ train()
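For reference, a sketch of what one training pair looks like after the "lean4" template above is applied. The template string and the model_response/statement_poof fields are taken from the script; the record contents are invented, and the literal "</s>" stands in for tokenizer.eos_token.

prompt_lean4 = (
    "Statement and proof in natural language:\n\n"
    "{statement_text}\n\n"
    "Translate the statement and proof in natural language to lean4:"
)

# Hypothetical record with the two fields the script reads.
record = {
    "model_response": "One plus one equals two; this follows by simplification.",
    "statement_poof": "theorem foo : 1 + 1 = 2 := by simp",
}

source = prompt_lean4.format(statement_text=record["model_response"])  # model input
target = record["statement_poof"] + "</s>"                             # supervised output
print(source)
print(target)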
autosolver_train.py ADDED
@@ -0,0 +1,303 @@
1
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Modified by Zheng Yuan and Hongyi Yuan
15
+
16
+ import os
17
+ import copy
18
+ import logging
19
+ from dataclasses import dataclass, field
20
+ from typing import Optional, Dict, Sequence
21
+ import io
22
+ import torch
23
+ import transformers
24
+ from torch.utils.data import Dataset
25
+ from transformers import Trainer
26
+ import argparse
27
+ import json
28
+ import random;random.seed(42)
29
+
30
+ def _make_r_io_base(f, mode: str):
31
+ if not isinstance(f, io.IOBase):
32
+ f = open(f, mode=mode)
33
+ return f
34
+
35
+ def jload(f, mode="r"):
36
+ """Load a .json file into a dictionary."""
37
+ f = _make_r_io_base(f, mode)
38
+ jdict = json.load(f)
39
+ f.close()
40
+ return jdict
41
+
42
+ IGNORE_INDEX = -100
43
+ DEFAULT_PAD_TOKEN = "[PAD]"
44
+ DEFAULT_EOS_TOKEN = "</s>"
45
+ DEFAULT_BOS_TOKEN = "<s>"
46
+ DEFAULT_UNK_TOKEN = "<unk>"
47
+ PROMPT_DICT = {
48
+ "lean4": (
49
+ "Statement and proof in natural language:\n\n"
50
+ "{statement_text}\n\n"
51
+ "Translate the statement and proof in natural language to lean4:"
52
+ ),
53
+ "plain": (
54
+ "{statement_text}"
55
+ ),
56
+ "prompt_no_input": (
57
+ "Below is an instruction that describes a task. "
58
+ "Write a response that appropriately completes the request.\n\n"
59
+ "### Instruction:\n{instruction}\n\n### Response:"
60
+ ),
61
+ }
62
+ #### 28
63
+ @dataclass
64
+ class ModelArguments:
65
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
66
+
67
+
68
+ @dataclass
69
+ class DataArguments:
70
+ data_path: str = field(default=None, metadata={"help": "Path to the training data."})
71
+
72
+
73
+ @dataclass
74
+ class TrainingArguments(transformers.TrainingArguments):
75
+ cache_dir: Optional[str] = field(default=None)
76
+ optim: str = field(default="adamw_torch")
77
+ model_max_length: int = field(
78
+ default=2048,
79
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
80
+ )
81
+ overwrite_output_dir: bool = field(default=True)
82
+
83
+
84
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
85
+ """Collects the state dict and dump to disk."""
86
+ state_dict = trainer.model.state_dict()
87
+ if trainer.args.should_save:
88
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
89
+ del state_dict
90
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
91
+
92
+
93
+ def smart_tokenizer_and_embedding_resize(
94
+ special_tokens_dict: Dict,
95
+ tokenizer: transformers.PreTrainedTokenizer,
96
+ model: transformers.PreTrainedModel,
97
+ ):
98
+ """Resize tokenizer and embedding.
99
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
100
+ """
101
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
102
+ model.resize_token_embeddings(len(tokenizer))
103
+
104
+ if num_new_tokens > 0:
105
+ input_embeddings = model.get_input_embeddings().weight.data
106
+ output_embeddings = model.get_output_embeddings().weight.data
107
+
108
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
109
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
110
+
111
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
112
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
113
+
114
+
115
+ def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
116
+ """Tokenize a list of strings."""
117
+ tokenized_list = [
118
+ tokenizer(
119
+ text,
120
+ return_tensors="pt",
121
+ padding="longest",
122
+ max_length=tokenizer.model_max_length,
123
+ truncation=True,
124
+ )
125
+ for text in strings
126
+ ]
127
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
128
+ input_ids_lens = labels_lens = [
129
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
130
+ ]
131
+ return dict(
132
+ input_ids=input_ids,
133
+ labels=labels,
134
+ input_ids_lens=input_ids_lens,
135
+ labels_lens=labels_lens,
136
+ )
137
+
138
+
139
+ def preprocess(
140
+ sources: Sequence[str],
141
+ targets: Sequence[str],
142
+ tokenizer: transformers.PreTrainedTokenizer,
143
+ ) -> Dict:
144
+ """Preprocess the data by tokenizing."""
145
+ examples = [s + t for s, t in zip(sources, targets)]
146
+ examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
147
+ input_ids = examples_tokenized["input_ids"]
148
+ labels = copy.deepcopy(input_ids)
149
+ for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
150
+ label[:source_len] = IGNORE_INDEX
151
+ return dict(input_ids=input_ids, labels=labels)
152
+
153
+ class SupervisedDataset(Dataset):
154
+ """Dataset for supervised fine-tuning."""
155
+ def __init__(self, data_args, tokenizer: transformers.PreTrainedTokenizer):
156
+ super(SupervisedDataset, self).__init__()
157
+ logging.warning("Loading data...")
158
+ data_path = data_args.data_path
159
+ try:
160
+ data_path = data_path_map[data_path]
161
+ except:
162
+ data_path = data_path
163
+ list_data_dict = []
164
+ for item in data_path.split(','):
165
+ try:
166
+ list_data_dict += jload(item)
167
+
168
+ except BaseException:
169
+ with open(item, 'r') as f:
170
+ lines = f.readlines()
171
+ list_data_dict += [json.loads(line.strip()) for line in lines]
172
+
173
+ list_data_dict = random.sample(list_data_dict, len(list_data_dict))
174
+ list_data_dict = list_data_dict[:data_args.data_length]
175
+
176
+ logging.warning("Formatting inputs...")
177
+ prompt_lean4 = PROMPT_DICT["plain"]
178
+
179
+ # list_data_dict = [{'instruction':data['items'][0]['value'], 'input':'', 'output':data['items'][1]['value']} for data in list_data_dict]
180
+
181
+ list_data_dict = [{'instruction':prompt_lean4.format(statement_text = data['statement']), 'input':'', 'output':data['proof']} for data in list_data_dict]
182
+ print(f"len of {len(list_data_dict)}")
183
+ sources = [example['instruction'] for example in list_data_dict]
184
+ targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
185
+ # targets = [example['output'] for example in list_data_dict]
186
+
187
+ self.sources = sources
188
+ self.targets = targets
189
+
190
+ def __len__(self):
191
+ return len(self.sources)
192
+
193
+ def naive__getitem__(self, i) -> Dict[str, torch.Tensor]:
194
+ return dict(input_ids=self.input_ids[i], labels=self.labels[i])
195
+
196
+ def __getitem__(self, i):
197
+ return dict(input_ids=self.sources[i], labels=self.targets[i])
198
+
199
+ @dataclass
200
+ class DataCollatorForSupervisedDataset(object):
201
+ """Collate examples for supervised fine-tuning."""
202
+
203
+ tokenizer: transformers.PreTrainedTokenizer
204
+
205
+ def naive__call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
206
+ input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
207
+ input_ids = torch.nn.utils.rnn.pad_sequence(
208
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
209
+ )
210
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
211
+ return dict(
212
+ input_ids=input_ids,
213
+ labels=labels,
214
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
215
+ )
216
+
217
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
218
+ sources = []
219
+ targets = []
220
+ for instance in instances:
221
+ source = instance['input_ids']
222
+ target = instance['labels']
223
+ sources.append(source)
224
+ targets.append(target)
225
+
226
+ data_dict = preprocess(sources, targets, self.tokenizer)
227
+ input_ids, labels = data_dict['input_ids'], data_dict['labels']
228
+ # input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
229
+ input_ids = torch.nn.utils.rnn.pad_sequence(
230
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
231
+ )
232
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
233
+ return dict(
234
+ input_ids=input_ids,
235
+ labels=labels,
236
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
237
+ )
238
+
239
+ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
240
+ """Make dataset and collator for supervised fine-tuning."""
241
+ train_dataset = SupervisedDataset(tokenizer=tokenizer, data_args=data_args)
242
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
243
+ return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
244
+
245
+
246
+ os.environ["WANDB_PROJECT"] = "auto_solver"
247
+
248
+ def train():
249
+
250
+ parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
251
+ model_args, data_args, training_args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
252
+ data_args.data_length = int(remaining_args[1])
253
+
254
+ model = transformers.AutoModelForCausalLM.from_pretrained(
255
+ model_args.model_name_or_path,
256
+ cache_dir=training_args.cache_dir,
257
+ trust_remote_code=True,
258
+ torch_dtype=torch.bfloat16,
259
+ attn_implementation="flash_attention_2",
260
+ )
261
+
262
+ model.config.use_cache = False
263
+ model.gradient_checkpointing_enable()
264
+
265
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
266
+ model_args.model_name_or_path,
267
+ cache_dir=training_args.cache_dir,
268
+ model_max_length=training_args.model_max_length,
269
+ padding_side="right",
270
+ use_fast=False,
271
+ )
272
+ if tokenizer.pad_token is None:
273
+ smart_tokenizer_and_embedding_resize(
274
+ special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
275
+ tokenizer=tokenizer,
276
+ model=model,
277
+ )
278
+ if "llama" in model_args.model_name_or_path:
279
+ tokenizer.add_special_tokens(
280
+ {
281
+ "eos_token": DEFAULT_EOS_TOKEN,
282
+ "bos_token": DEFAULT_BOS_TOKEN,
283
+ "unk_token": DEFAULT_UNK_TOKEN,
284
+ }
285
+ )
286
+ try:
287
+ tokenizer.pad_token = tokenizer.unk_token
288
+ except:
289
+ pass
290
+
291
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
292
+ trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
293
+ trainer.train()
294
+ model.config.use_cache = True
295
+ # trainer.save_state()
296
+ # if os.environ.get('LOCAL_RANK') == '0':
297
+ safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
298
+
299
+
300
+
301
+
302
+ if __name__ == "__main__":
303
+ train()
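A compact sketch of the batching done by DataCollatorForSupervisedDataset above: input ids are right-padded with the tokenizer's pad id, labels with IGNORE_INDEX, and the attention mask marks the non-pad positions. The pad id of 0 and the token ids are assumptions for illustration.

import torch

IGNORE_INDEX = -100
pad_token_id = 0  # assumed pad id, purely for illustration

input_ids = [torch.tensor([5, 8, 2]), torch.tensor([9, 4])]
labels = [torch.tensor([-100, 8, 2]), torch.tensor([-100, 4])]

batch_inputs = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
batch_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
attention_mask = batch_inputs.ne(pad_token_id)

print(batch_inputs)    # tensor([[5, 8, 2], [9, 4, 0]])
print(batch_labels)    # tensor([[-100,  8,  2], [-100,  4, -100]])
print(attention_mask)  # tensor([[True, True, True], [True, True, False]])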
autostatement_train.py ADDED
@@ -0,0 +1,308 @@
1
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Modified by Zheng Yuan and Hongyi Yuan
15
+
16
+ import os
17
+ import copy
18
+ import logging
19
+ from dataclasses import dataclass, field
20
+ from typing import Optional, Dict, Sequence
21
+ import io
22
+ import torch
23
+ import transformers
24
+ from torch.utils.data import Dataset
25
+ from transformers import Trainer
26
+ import argparse
27
+ import json
28
+ import random;random.seed(42)
29
+
30
+ def _make_r_io_base(f, mode: str):
31
+ if not isinstance(f, io.IOBase):
32
+ f = open(f, mode=mode)
33
+ return f
34
+
35
+ def jload(f, mode="r"):
36
+ """Load a .json file into a dictionary."""
37
+ f = _make_r_io_base(f, mode)
38
+ jdict = json.load(f)
39
+ f.close()
40
+ return jdict
41
+
42
+ IGNORE_INDEX = -100
43
+ DEFAULT_PAD_TOKEN = "[PAD]"
44
+ DEFAULT_EOS_TOKEN = "</s>"
45
+ DEFAULT_BOS_TOKEN = "<s>"
46
+ DEFAULT_UNK_TOKEN = "<unk>"
47
+ PROMPT_DICT = {
48
+ "lean4": (
49
+ "Statement and proof in natural language:\n\n"
50
+ "{statement_text}\n\n"
51
+ "Translate the statement and proof in natural language to lean4:"
52
+ ),
53
+ "plain": (
54
+ "{statement_text}"
55
+ ),
56
+ "statement": (
57
+ "Statement in natural language:\n"
58
+ "{problem}\n"
59
+ "Translate the statement in natural language to Lean4:"
60
+ ),
61
+ "prompt_no_input": (
62
+ "Below is an instruction that describes a task. "
63
+ "Write a response that appropriately completes the request.\n\n"
64
+ "### Instruction:\n{instruction}\n\n### Response:"
65
+ ),
66
+ }
67
+ #### 28
68
+ @dataclass
69
+ class ModelArguments:
70
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
71
+
72
+
73
+ @dataclass
74
+ class DataArguments:
75
+ data_path: str = field(default=None, metadata={"help": "Path to the training data."})
76
+
77
+
78
+ @dataclass
79
+ class TrainingArguments(transformers.TrainingArguments):
80
+ cache_dir: Optional[str] = field(default=None)
81
+ optim: str = field(default="adamw_torch")
82
+ model_max_length: int = field(
83
+ default=2048,
84
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
85
+ )
86
+ overwrite_output_dir: bool = field(default=True)
87
+
88
+
89
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
90
+ """Collects the state dict and dump to disk."""
91
+ state_dict = trainer.model.state_dict()
92
+ if trainer.args.should_save:
93
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
94
+ del state_dict
95
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
96
+
97
+
98
+ def smart_tokenizer_and_embedding_resize(
99
+ special_tokens_dict: Dict,
100
+ tokenizer: transformers.PreTrainedTokenizer,
101
+ model: transformers.PreTrainedModel,
102
+ ):
103
+ """Resize tokenizer and embedding.
104
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
105
+ """
106
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
107
+ model.resize_token_embeddings(len(tokenizer))
108
+
109
+ if num_new_tokens > 0:
110
+ input_embeddings = model.get_input_embeddings().weight.data
111
+ output_embeddings = model.get_output_embeddings().weight.data
112
+
113
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
114
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
115
+
116
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
117
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
118
+
119
+
120
+ def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
121
+ """Tokenize a list of strings."""
122
+ tokenized_list = [
123
+ tokenizer(
124
+ text,
125
+ return_tensors="pt",
126
+ padding="longest",
127
+ max_length=tokenizer.model_max_length,
128
+ truncation=True,
129
+ )
130
+ for text in strings
131
+ ]
132
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
133
+ input_ids_lens = labels_lens = [
134
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
135
+ ]
136
+ return dict(
137
+ input_ids=input_ids,
138
+ labels=labels,
139
+ input_ids_lens=input_ids_lens,
140
+ labels_lens=labels_lens,
141
+ )
142
+
143
+
144
+ def preprocess(
145
+ sources: Sequence[str],
146
+ targets: Sequence[str],
147
+ tokenizer: transformers.PreTrainedTokenizer,
148
+ ) -> Dict:
149
+ """Preprocess the data by tokenizing."""
150
+ examples = [s + t for s, t in zip(sources, targets)]
151
+ examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
152
+ input_ids = examples_tokenized["input_ids"]
153
+ labels = copy.deepcopy(input_ids)
154
+ for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
155
+ label[:source_len] = IGNORE_INDEX
156
+ return dict(input_ids=input_ids, labels=labels)
157
+
158
+ class SupervisedDataset(Dataset):
159
+ """Dataset for supervised fine-tuning."""
160
+ def __init__(self, data_args, tokenizer: transformers.PreTrainedTokenizer):
161
+ super(SupervisedDataset, self).__init__()
162
+ logging.warning("Loading data...")
163
+ data_path = data_args.data_path
164
+ try:
165
+ data_path = data_path_map[data_path]
166
+ except:
167
+ data_path = data_path
168
+ list_data_dict = []
169
+ for item in data_path.split(','):
170
+ try:
171
+ list_data_dict += jload(item)
172
+
173
+ except BaseException:
174
+ with open(item, 'r') as f:
175
+ lines = f.readlines()
176
+ list_data_dict += [json.loads(line.strip()) for line in lines]
177
+
178
+ list_data_dict = random.sample(list_data_dict, len(list_data_dict))
179
+ list_data_dict = list_data_dict[:data_args.data_length]
180
+
181
+ logging.warning("Formatting inputs...")
182
+ prompt_lean4 = PROMPT_DICT["statement"]
183
+
184
+ # list_data_dict = [{'instruction':data['items'][0]['value'], 'input':'', 'output':data['items'][1]['value']} for data in list_data_dict]
185
+
186
+ list_data_dict = [{'instruction':prompt_lean4.format(problem= data['problem']), 'input':'', 'output':data['statement']} for data in list_data_dict]
187
+ print(f"len of {len(list_data_dict)}")
188
+ sources = [example['instruction'] for example in list_data_dict]
189
+ targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
190
+ # targets = [example['output'] for example in list_data_dict]
191
+
192
+ self.sources = sources
193
+ self.targets = targets
194
+
195
+ def __len__(self):
196
+ return len(self.sources)
197
+
198
+ def naive__getitem__(self, i) -> Dict[str, torch.Tensor]:
199
+ return dict(input_ids=self.input_ids[i], labels=self.labels[i])
200
+
201
+ def __getitem__(self, i):
202
+ return dict(input_ids=self.sources[i], labels=self.targets[i])
203
+
204
+ @dataclass
205
+ class DataCollatorForSupervisedDataset(object):
206
+ """Collate examples for supervised fine-tuning."""
207
+
208
+ tokenizer: transformers.PreTrainedTokenizer
209
+
210
+ def naive__call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
211
+ input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
212
+ input_ids = torch.nn.utils.rnn.pad_sequence(
213
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
214
+ )
215
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
216
+ return dict(
217
+ input_ids=input_ids,
218
+ labels=labels,
219
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
220
+ )
221
+
222
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
223
+ sources = []
224
+ targets = []
225
+ for instance in instances:
226
+ source = instance['input_ids']
227
+ target = instance['labels']
228
+ sources.append(source)
229
+ targets.append(target)
230
+
231
+ data_dict = preprocess(sources, targets, self.tokenizer)
232
+ input_ids, labels = data_dict['input_ids'], data_dict['labels']
233
+ # input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
234
+ input_ids = torch.nn.utils.rnn.pad_sequence(
235
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
236
+ )
237
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
238
+ return dict(
239
+ input_ids=input_ids,
240
+ labels=labels,
241
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
242
+ )
243
+
244
+ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
245
+ """Make dataset and collator for supervised fine-tuning."""
246
+ train_dataset = SupervisedDataset(tokenizer=tokenizer, data_args=data_args)
247
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
248
+ return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
249
+
250
+
251
+ os.environ["WANDB_PROJECT"] = "auto_statement"
252
+
253
+ def train():
254
+
255
+ parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
256
+ model_args, data_args, training_args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
257
+ data_args.data_length = int(remaining_args[1])
258
+
259
+ model = transformers.AutoModelForCausalLM.from_pretrained(
260
+ model_args.model_name_or_path,
261
+ cache_dir=training_args.cache_dir,
262
+ trust_remote_code=True,
263
+ torch_dtype=torch.bfloat16,
264
+ attn_implementation="flash_attention_2",
265
+ )
266
+
267
+ model.config.use_cache = False
268
+ model.gradient_checkpointing_enable()
269
+
270
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
271
+ model_args.model_name_or_path,
272
+ cache_dir=training_args.cache_dir,
273
+ model_max_length=training_args.model_max_length,
274
+ padding_side="right",
275
+ use_fast=False,
276
+ )
277
+ if tokenizer.pad_token is None:
278
+ smart_tokenizer_and_embedding_resize(
279
+ special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
280
+ tokenizer=tokenizer,
281
+ model=model,
282
+ )
283
+ if "llama" in model_args.model_name_or_path:
284
+ tokenizer.add_special_tokens(
285
+ {
286
+ "eos_token": DEFAULT_EOS_TOKEN,
287
+ "bos_token": DEFAULT_BOS_TOKEN,
288
+ "unk_token": DEFAULT_UNK_TOKEN,
289
+ }
290
+ )
291
+ try:
292
+ tokenizer.pad_token = tokenizer.unk_token
293
+ except:
294
+ pass
295
+
296
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
297
+ trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
298
+ trainer.train()
299
+ model.config.use_cache = True
300
+ # trainer.save_state()
301
+ # if os.environ.get('LOCAL_RANK') == '0':
302
+ safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
303
+
304
+
305
+
306
+
307
+ if __name__ == "__main__":
308
+ train()
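All four training scripts read data_args.data_length out of remaining_args[1], i.e. out of whatever unrecognized flag/value pair the launch command appends after the recognized arguments. A small sketch of that mechanism (the --data_length flag name and the trimmed-down dataclass are assumptions for illustration; HfArgumentParser is the transformers class the scripts actually use):

from dataclasses import dataclass, field
from typing import Optional

import transformers

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")

parser = transformers.HfArgumentParser(ModelArguments)
# With return_remaining_strings=True, unrecognized CLI strings come back as a list,
# so an extra "--data_length 1000000" arrives as ['--data_length', '1000000'].
model_args, remaining = parser.parse_args_into_dataclasses(
    args=["--model_name_or_path", "gpt2", "--data_length", "1000000"],
    return_remaining_strings=True,
)
print(model_args.model_name_or_path)  # gpt2
print(int(remaining[1]))              # 1000000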
generation_method.py ADDED
@@ -0,0 +1,472 @@
1
+ import argparse
2
+ import random
3
+ import glob
4
+
5
+ from tqdm import tqdm
6
+ import re
7
+ import sys
8
+ import os
9
+ import numpy as np
10
+
11
+
12
+ def generate_few_shot(prompt):
13
+ base_gsm8k_list = [
14
+ {
15
+ 'question': "John and his best friend Steve bought 12 cupcakes together. Each cupcake cost $1.50. If they split the costs evenly, how much did each person pay?",
16
+ 'answer': "The total cost of cupcakes was 1.5*12=$<<1.5*12=18>>18\\nSo they each paid 18/2=$<<18/2=9>>9.",
17
+ 'direct_answer': "9"
18
+ },
19
+ {
20
+ 'question': "Lizzy has to ship 540 pounds of fish that are packed into 30-pound crates. If the shipping cost of each crate is $1.5, how much will Lizzy pay for the shipment?",
21
+ 'answer': "There are 540 pounds / 30 pounds/crate = <<540/30=18>>18 crates of fish needed.\\nHence, the total cost for the shipment is $1.5/crate x 18 crates = $<<1.5*18=27>>27.",
22
+ 'direct_answer': "27"
23
+ },
24
+ {
25
+ 'question': "Tom, Tim, and Paul are collecting photos of cars. Paul has 10 photos more than Tim. Tim has one hundred photos less than the total amount of photos which is 152. How many photos does Tom have?",
26
+ 'answer': "Tim has 152 photos - 100 photos = <<152-100=52>>52 photos.\\nWhen Tim has 52 photos, then Paul has 52 + 10 photos = <<52+10=62>>62 photos.\\nTim and Paul have together 52 photos + 62 photos = <<52+62=114>>114 photos.\\nThat leaves Tom with 152 photos - 114 photos = <<152-114=38>>38 photos.",
27
+ 'direct_answer': "38"
28
+ },
29
+
30
+ ]
31
+ index_list = list(range(len(base_gsm8k_list)))
32
+ random.shuffle(index_list)
33
+ few_shot_example = ""
34
+ for i in index_list:
35
+ item = base_gsm8k_list[i]
36
+ few_shot_example += "Q: " + item['question'] + "\n" + "A: "+ item['answer'] + "\nThe answer is " + item['direct_answer'] + "\n"
37
+
38
+ few_shot_example += "Q: " + prompt + "A: "
39
+ return few_shot_example
40
+
41
+
42
+
43
+ def generate_prompt_generation(args, question):
44
+ if args.evaluation_mode == 'generation':
45
+ if args.method == 'zero_shot_cot':
46
+ content = question + " Let's think step by step."
47
+ elif args.method == 'zero_shot':
48
+ content = question
49
+ elif args.method == 'few_shot':
50
+ content = generate_few_shot(question)
51
+ else:
52
+ raise ValueError("we do not method for such model type yet")
53
+
54
+ if "generator" not in args.model_type:
55
+ MODEL_DICT = {
56
+ "llama": (
57
+ "[INST] \n{content}\n [/INST]"
58
+ ),
59
+ "mistral": (
60
+ "<s>[INST] {content} [/INST]"
61
+ ),
62
+ "chatglm": (
63
+ "<|user|> \n{content}\n <|assistant|>"
64
+ ),
65
+ "qianwen": (
66
+ "<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
67
+ ),
68
+ "baichuan": (
69
+ "<reserved_106>{content}<reserved_107>"
70
+ )
71
+ }
72
+
73
+ if args.model_type in ["qianwen", "qianwen-13b", "qianwen-70b"]:
74
+ content = MODEL_DICT['qianwen'].format_map(
75
+ {'content': content}
76
+ )
77
+
78
+ elif args.model_type in ["chatglm"]:
79
+ pass
80
+
81
+
82
+ elif args.model_type in ['llama2-7b-chat']:
83
+ content = MODEL_DICT['llama'].format_map(
84
+ {'content': content}
85
+ )
86
+
87
+ elif args.model_type in ["mistral", 'mixtral']:
88
+ content = MODEL_DICT['mistral'].format_map(
89
+ {'content': content}
90
+ )
91
+
92
+
93
+ return content
94
+
95
+
96
+
97
+
98
+ few_shot_list = [
99
+ {
100
+ 'question': "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?",
101
+ 'answer': "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6.",
102
+ 'direct_answer': "6"
103
+ },
104
+ {
105
+ 'question': "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
106
+ 'answer': "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.",
107
+ 'direct_answer': "5",
108
+ },
109
+ {
110
+ 'question': "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?",
111
+ 'answer': "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39.",
112
+ 'direct_answer': "39",
113
+ },
114
+ {
115
+ 'question': "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?",
116
+ 'answer': "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8.",
117
+ 'direct_answer': "8",
118
+ },
119
+ {
120
+ 'question': "Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?",
121
+ 'answer': "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9.",
122
+ 'direct_answer': "9",
123
+ },
124
+ {
125
+ 'question': "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?",
126
+ 'answer': "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29.",
127
+ 'direct_answer': "29",
128
+ },
129
+ {
130
+ 'question': "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?",
131
+ 'answer': "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls.",
132
+ 'direct_answer': "33",
133
+ },
134
+ {
135
+ 'question': "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?",
136
+ 'answer': "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8.",
137
+ 'direct_answer': "8",
138
+ },
139
+ ]
140
+ import json
141
+
142
+ from collections import Counter
143
+
144
+
145
+ def self_consistency(pairs):
146
+ val_counts = Counter(value for key, value in pairs)
147
+ most = val_counts.most_common(1)[0][0]
148
+ for key, value in pairs:
149
+ if value == most:
150
+ return key
151
+
152
+
153
+ #
154
+ def find_feedback(content):
155
+ match = re.search(r'Judgement: (.+)', content)
156
+ if match:
157
+ judgement = match.group(1)
158
+ else:
159
+ judgement = "None"
160
+ return judgement
161
+
162
+
163
+ def str2bool(s):
164
+ s = s.lower()
165
+ if s == 'true':
166
+ return True
167
+ elif s == 'false':
168
+ return False
169
+ else:
170
+ raise ValueError('invalid value: {}, must be true or false'.format(s))
171
+
172
+
173
+ def parse_arguments():
174
+ parser = argparse.ArgumentParser(description="Zero-shot-CoT")
175
+
176
+ # parser.add_argument(
177
+ # "--dataset", type=str, default="plan",
178
+ # choices=["plan", 'tool_use_awareness', 'tool_selection', 'tool_selection_harder', 'tool_creation_awareness',
179
+ # 'tool_creation_awareness_harder', 'tool_creation',
180
+ # 'arguments_filling'], help="dataset used for experiment")
181
+ parser.add_argument(
182
+ "--cot_trigger_no", type=int, default=1,
183
+ help="A trigger sentence that elicits a model to execute chain of thought"
184
+ )
185
+ parser.add_argument("--dataset", type=str, default="")
186
+ parser.add_argument("--data_path", type=str, default="")
187
+ parser.add_argument("--evaluation_mode", type=str, default="")
188
+ parser.add_argument("--batch_size", type=int, default=1)
189
+ parser.add_argument("--eval_method", type=str, default="")
190
+
191
+ parser.add_argument("--model_path", type=str, default="")
192
+
193
+ parser.add_argument("--model_type", type=str, default="chatglm")
194
+
195
+ parser.add_argument("--output_dir", type=str, default="generation_test")
196
+
197
+ parser.add_argument("--lora_path", type=str, default="")
198
+
199
+ parser.add_argument("--iter_num", type=int, default=1)
200
+ parser.add_argument("--method", type=str, default="few_shot_cot")
201
+ parser.add_argument("--data_question_key", type=str, default="question")
202
+ parser.add_argument("--data_answer_key", type=str, default="answer")
203
+
204
+ parser.add_argument("--sample_num", type=int, default=1)
205
+
206
+ parser.add_argument("--cuda_ind", type=int, default=0)
207
+ parser.add_argument("--tensor_parallel", type=int, default=1)
208
+ parser.add_argument("--cuda_start", type=int, default=0)
209
+ parser.add_argument("--cuda_num", type=int, default=8)
210
+
211
+ parser.add_argument("--load_in_8bit", type=str2bool, default=False)
212
+ parser.add_argument("--rewrite", type=str2bool, default=True)
213
+ parser.add_argument("--notlean", type=str2bool, default=True)
214
+
215
+ parser.add_argument("--use_typewriter", type=int, default=0)
216
+
217
+ parser.add_argument("--temperature", type=float, default=0.0)
218
+ parser.add_argument("--top_p", type=float, default=1)
219
+ parser.add_argument("--iter_max_new_tokens", type=int, default=512)
220
+ parser.add_argument("--init_max_new_tokens", type=int, default=2048)
221
+ parser.add_argument("--min_new_tokens", type=int, default=1)
222
+ parser.add_argument("--correct_response_format", type=str, default="The correct response is:")
223
+
224
+ args = parser.parse_args()
225
+ if args.evaluation_mode == 'generation':
226
+ if "lean" in args.dataset:
227
+ args.data_question_key = 'model_response'
228
+ args.data_answer_key = 'statement_poof'
229
+
230
+ if args.dataset == "lean4_5k_test":
231
+ args.data_path = "data/lean4_gpt_5k/test/data.jsonl"
232
+
233
+ elif args.dataset == "math_train":
234
+ args.data_path = "data/test/math/train.jsonl"
235
+
236
+ elif args.dataset == "gsm8k_train":
237
+ args.data_path = "data/test/gsm8k/train.jsonl"
238
+
239
+ elif args.dataset == "wild_test":
240
+ args.data_path = "/hpc2hdd/home/zyang398/data_2/wild_sample1k.jsonl"
241
+
242
+ elif args.dataset == "lean4_basic_test":
243
+ args.data_path = "data/lean4_basic/1k_test.jsonl"
244
+ elif args.dataset == "lean4_random_test":
245
+ args.data_path = "data/lean4_random/1k_test.json"
246
+ elif args.dataset == "lean4_random_first_train":
247
+ args.data_path = "data/lean4_random/5k_first.json"
248
+ elif args.dataset == "lean4_random_second_train":
249
+ args.data_path = "data/lean4_random/5k_second.json"
250
+ elif args.dataset == "lean4_random_third_train":
251
+ args.data_path = "data/lean4_random/5k_third.json"
252
+
253
+ if args.model_type == 'mistral_generator':
254
+ args.model_path = 'models/gsm8k/generators/mistral-ep2/'
255
+ elif args.model_type == 'mistral_generator_original':
256
+ args.model_path = '/data/OVM-Mistral-7b/mistral7b-ep2/'
257
+ elif args.model_type == 'gemma_generator':
258
+ args.model_path = 'models/gsm8k/generators/gemma2b2-ep2/'
259
+ elif args.model_type == 'phi2_generator':
260
+ args.model_path = 'models/gsm8k/generators/phi2b-ep2/'
261
+
262
+ elif args.model_type == 'mixtral':
263
+ args.model_path = '/data/Mixtral-8x7B-Instruct-v0.1'
264
+
265
+ elif args.model_type == 'mistral':
266
+ args.model_path = '/data/mistral-instruct'
267
+
268
+ elif args.model_type == 'qianwen-70b':
269
+ args.model_path = '/data/Qwen-72B-Chat'
270
+
271
+
272
+ elif args.model_type == 'llama2-7b-chat':
273
+ args.model_path = '/data/Llama-2-7b-chat/'
274
+
275
+ if args.cot_trigger_no == 1:
276
+ args.cot_trigger = "Let's think step by step."
277
+
278
+ return args
279
+
280
+
281
+ def create_demo_text(args, cot_flag, index_list):
282
+ # Concatenate demonstration examples ...
283
+ demo_text = ""
284
+ for i in index_list:
285
+ item = few_shot_list[i]
286
+ if cot_flag:
287
+ demo_text += "Q: " + item['question'] + "\nA: " + item['answer'] + " " + \
288
+ args.direct_answer_trigger_for_fewshot + " " + item['direct_answer'] + ".\n\n"
289
+ else:
290
+ demo_text += "Q: " + item['question'] + "\nA: " + \
291
+ args.direct_answer_trigger_for_fewshot + " " + item['direct_answer'] + ".\n\n"
292
+
293
+ return demo_text
294
+
295
+
296
+ def str2bool(s):
297
+ s = s.lower()
298
+ if s == 'true':
299
+ return True
300
+ elif s == 'false':
301
+ return False
302
+ else:
303
+ raise ValueError('invalid value: {}, must be true or false'.format(s))
304
+
305
+
306
+ def batchify(pairs, batch_size):
307
+
308
+ """将列表分成指定大小的批次"""
309
+ for i in range(0, len(pairs), batch_size):
310
+ yield pairs[i:i + batch_size]
311
+
312
+
313
+ def generate_prompts(questions, args):
314
+ """为每个问题生成提示"""
315
+ prompts = [generate_prompt_generation(args, question) for question in questions]
316
+ return prompts
317
+
318
+ PROMPT_DICT = {
319
+ "wild": (
320
+ "Statement and proof in natural language:\n\n"
321
+ "# Problem:\n{question}\n\n"
322
+ "# Proof:\n{answer}\n\n"
323
+ "Translate the statement and proof in natural language to lean4:"
324
+ ),
325
+ "lean4": (
326
+ "Statement and proof in natural language:\n\n"
327
+ "{statement_text}\n\n"
328
+ "Translate the statement and proof in natural language to lean4:"
329
+ ),
330
+ "prompt_no_input": (
331
+ "Below is an instruction that describes a task. "
332
+ "Write a response that appropriately completes the request.\n\n"
333
+ "### Instruction:\n{instruction}\n\n### Response:"
334
+ ),
335
+ }
336
+
337
+ def get_question_answer(args):
338
+ allfilepath = args.data_path
339
+ questions = []
340
+ answers = []
341
+
342
+ # Attempt to read the file as a regular JSON file
343
+ for filepath in allfilepath.split(','):
344
+ try:
345
+ with open(filepath, 'r') as file:
346
+ data = json.load(file)
347
+ # If the data is a list, assume it's an array of objects
348
+ if isinstance(data, list):
349
+ for json_item in data:
350
+ questions.append(json_item[args.data_question_key])
351
+ answers.append(json_item)
352
+ # If the data is a dict, assume it's a single object (or adjust logic as needed)
353
+ elif isinstance(data, dict):
354
+ questions.append(data[args.data_question_key])
355
+ answers.append(data)
356
+
357
+ except ValueError:
358
+ # If it fails, assume the file is in JSON Lines format
359
+ with open(filepath, 'r') as file:
360
+ for line in file:
361
+ json_item = json.loads(line)
362
+ questions.append(json_item[args.data_question_key])
363
+ answers.append(json_item)
364
+
365
+ if args.notlean :
366
+ questions = [ PROMPT_DICT['wild'].format(question= questions[id], answer =answers[id][args.data_answer_key] ) for id in range(len(questions))]
367
+
368
+ else:
369
+ questions = [ PROMPT_DICT['lean4'].format(statement_text = item) for item in questions]
370
+
371
+
372
+ return questions, answers
373
+
374
+
375
+ def main3(args):
376
+ from vllm import LLM, SamplingParams
377
+ import torch
378
+
379
+ model = LLM(model=args.model_path, dtype="bfloat16", trust_remote_code=True,
380
+ tensor_parallel_size=args.tensor_parallel, gpu_memory_utilization = 0.95)
381
+ print(args.model_path)
382
+
383
+ if "qianwen" in args.model_type:
384
+ model.llm_engine.tokenizer.eos_token_id = 151645
385
+ # model.llm_engine.tokenizer.pad_token_id = 151645
386
+ model.llm_engine.tokenizer.pad_token_id = None
387
+ # model.llm_engine.tokenizer.eos_token_id = None
388
+
389
+
390
+ print("load data")
391
+
392
+
393
+ questions, answers = get_question_answer(args)
394
+
395
+
396
+
397
+ question_exist_list = []
398
+ write_pattern = 'w' if args.rewrite else "a+"
399
+ if os.path.exists(args.output_dir) and not args.rewrite :
400
+ # If the output directory already exists, read the previously generated questions so they can be skipped
401
+ # Loop through each file that matches the pattern
402
+ file_pattern = os.path.join(args.output_dir, '[0-9]*.json')
403
+ for file_path in glob.glob(file_pattern):
404
+ # Open and read the JSON file
405
+ with open(file_path, 'r') as fp:
406
+ # Extract the 'question' field from each line and add it to the list
407
+ for line in fp.readlines():
408
+ question_exist_list.append(json.loads(line)['question'])
409
+ else:
410
+ try:
411
+ os.mkdir(args.output_dir)
412
+ except:
413
+ pass
414
+ qa_pairs = [(questions[idx], answers[idx]) for idx in range(len(questions)) if questions[idx] not in question_exist_list ]
415
+ cuda_pieces = np.array_split(range(len(qa_pairs)), args.cuda_num // args.tensor_parallel)
416
+ print(f"fitered {len(questions) - len(qa_pairs)} already")
417
+
418
+ with open(f"{args.output_dir}/{args.cuda_ind // args.tensor_parallel + args.cuda_start}.json", write_pattern,
419
+ encoding='utf-8') as wf:
420
+ start = cuda_pieces[args.cuda_start + args.cuda_ind // args.tensor_parallel][0]
421
+ end = cuda_pieces[args.cuda_start + args.cuda_ind // args.tensor_parallel][-1] + 1
422
+ subset_length = end - start
423
+ total_batches = (subset_length + args.batch_size - 1) // args.batch_size # Calculate the total number of batches
424
+ for batch in tqdm(batchify(qa_pairs[start:end], args.batch_size), total=total_batches):
425
+ questions, answers = zip(*batch)  # unzip the questions and answers
426
+ prompts = generate_prompts(questions, args)
427
+
428
+ with torch.no_grad():
429
+ output_all = []
430
+ try:
431
+ for i in range(args.sample_num):
432
+ sample_list = []
433
+ sampling_params = SamplingParams(temperature=args.temperature, top_p=args.top_p,
434
+ max_tokens=args.init_max_new_tokens)
435
+ generations = model.generate(prompts, sampling_params, use_tqdm=False)
436
+ for generation_output in generations:
437
+ output = generation_output.outputs[0].text
438
+ sample_list.append(output)
439
+ output_all.append(sample_list)
440
+
441
+ output_all = list(map(list, zip(*output_all)))
442
+ except Exception as e:
443
+ print(str(e))
444
+ raise SystemExit(1)  # abort: generation failed
445
+ dicts = []
446
+ for question, answer, output, prompt in zip(questions, answers, output_all, prompts):
447
+ dicts.append({
448
+ "question": question,
449
+ "prompt": prompt,
450
+ "content": answer,
451
+ "total output": output,
452
+ })
453
+
454
+ for record in dicts:
455
+ wf.writelines(json.dumps(record, ensure_ascii=False) + '\n')
456
+
457
+ wf.flush()
458
+
459
+
460
+ def main(argv=None):
461
+ args = parse_arguments()
462
+ print('*****************************')
463
+ print(args)
464
+ print('*****************************')
465
+ if args.evaluation_mode == 'generation':
466
+ main3(args)
467
+ else:
468
+ raise ValueError("we do not yet inplement")
469
+
470
+
471
+ if __name__ == "__main__":
472
+ main()
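For reference, the generation flow above boils down to filling the `lean4` / `wild` templates and sending the prompts to vLLM in batches. Below is a minimal, illustrative sketch of just the prompt-building and batching step; the two natural-language statements are made up and no model is loaded.

lean4_template = (
    "Statement and proof in natural language:\n\n"
    "{statement_text}\n\n"
    "Translate the statement and proof in natural language to lean4:"
)

def batchify(items, batch_size):
    # yield successive slices of batch_size items
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

statements = [
    "If a and b are natural numbers, then a + b = b + a.",
    "The sum of two even numbers is even.",
]
prompts = [lean4_template.format(statement_text=s) for s in statements]
for batch in batchify(prompts, batch_size=1):
    print(batch[0][:60], "...")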
model_train.py ADDED
@@ -0,0 +1,318 @@
1
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Modified by Zheng Yuan and Hongyi Yuan
15
+
16
+ import os
17
+ import copy
18
+ import logging
19
+ from dataclasses import dataclass, field
20
+ from typing import Optional, Dict, Sequence
21
+ import io
22
+ import torch
23
+ import transformers
24
+ from torch.utils.data import Dataset
25
+ from transformers import Trainer
26
+ import argparse
27
+ import json
28
+ import random
29
+
30
+ random.seed(42)
31
+
32
+
33
+ def _make_r_io_base(f, mode: str):
34
+ if not isinstance(f, io.IOBase):
35
+ f = open(f, mode=mode)
36
+ return f
37
+
38
+
39
+ def jload(f, mode="r"):
40
+ """Load a .json file into a dictionary."""
41
+ f = _make_r_io_base(f, mode)
42
+ jdict = json.load(f)
43
+ f.close()
44
+ return jdict
45
+
46
+
47
+ IGNORE_INDEX = -100
48
+ DEFAULT_PAD_TOKEN = "[PAD]"
49
+ DEFAULT_EOS_TOKEN = "</s>"
50
+ DEFAULT_BOS_TOKEN = "<s>"
51
+ DEFAULT_UNK_TOKEN = "<unk>"
52
+ PROMPT_DICT = {
53
+ "statement_form": (
54
+ "Statement in natural language:\n"
55
+ "{problem}\n"
56
+ "Translate the statement in natural language to Lean4:"
57
+ ),
58
+ "solver": (
59
+ "{statement_text}"
60
+ ),
61
+ "statementproof_inform": (
62
+ "Statement and proof in lean4:\n\n"
63
+ "{statement_text}\n\n"
64
+ "Translate the statement and proof in lean4 to natural language:"
65
+ ),
66
+ }
67
+ KEY_DICT = {
68
+ "statement_form" : ["problem", "statement"],
69
+ "solver": ["statement", "proof"],
70
+ "statementproof_inform": ["statement_poof", "model_response"]
71
+ }
72
+
73
+
74
+ #### 28
75
+ @dataclass
76
+ class ModelArguments:
77
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
78
+
79
+
80
+ @dataclass
81
+ class DataArguments:
82
+ data_path: str = field(default=None, metadata={"help": "Path to the training data."})
83
+
84
+
85
+ @dataclass
86
+ class TrainingArguments(transformers.TrainingArguments):
87
+ cache_dir: Optional[str] = field(default=None)
88
+ optim: str = field(default="adamw_torch")
89
+ model_max_length: int = field(
90
+ default=2048,
91
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
92
+ )
93
+ overwrite_output_dir: bool = field(default=True)
94
+
95
+
96
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
97
+ """Collects the state dict and dump to disk."""
98
+ state_dict = trainer.model.state_dict()
99
+ if trainer.args.should_save:
100
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
101
+ del state_dict
102
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
103
+
104
+
105
+ def smart_tokenizer_and_embedding_resize(
106
+ special_tokens_dict: Dict,
107
+ tokenizer: transformers.PreTrainedTokenizer,
108
+ model: transformers.PreTrainedModel,
109
+ ):
110
+ """Resize tokenizer and embedding.
111
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
112
+ """
113
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
114
+ model.resize_token_embeddings(len(tokenizer))
115
+
116
+ if num_new_tokens > 0:
117
+ input_embeddings = model.get_input_embeddings().weight.data
118
+ output_embeddings = model.get_output_embeddings().weight.data
119
+
120
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
121
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
122
+
123
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
124
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
125
+
126
+
127
+ def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
128
+ """Tokenize a list of strings."""
129
+ tokenized_list = [
130
+ tokenizer(
131
+ text,
132
+ return_tensors="pt",
133
+ padding="longest",
134
+ max_length=tokenizer.model_max_length,
135
+ truncation=True,
136
+ )
137
+ for text in strings
138
+ ]
139
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
140
+ input_ids_lens = labels_lens = [
141
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
142
+ ]
143
+ return dict(
144
+ input_ids=input_ids,
145
+ labels=labels,
146
+ input_ids_lens=input_ids_lens,
147
+ labels_lens=labels_lens,
148
+ )
149
+
150
+
151
+ def preprocess(
152
+ sources: Sequence[str],
153
+ targets: Sequence[str],
154
+ tokenizer: transformers.PreTrainedTokenizer,
155
+ ) -> Dict:
156
+ """Preprocess the data by tokenizing."""
157
+ examples = [s + t for s, t in zip(sources, targets)]
158
+ examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
159
+ input_ids = examples_tokenized["input_ids"]
160
+ labels = copy.deepcopy(input_ids)
161
+ for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
162
+ label[:source_len] = IGNORE_INDEX
163
+ return dict(input_ids=input_ids, labels=labels)
164
+
165
+
166
+ class SupervisedDataset(Dataset):
167
+ """Dataset for supervised fine-tuning."""
168
+
169
+ def __init__(self, data_args, tokenizer: transformers.PreTrainedTokenizer):
170
+ super(SupervisedDataset, self).__init__()
171
+ logging.warning("Loading data...")
172
+ data_path = data_args.data_path
173
+ try:
174
+ data_path = data_path_map[data_path]
175
+ except (NameError, KeyError):  # data_path_map is an optional alias table; fall back to the literal path
176
+ data_path = data_path
177
+ list_data_dict = []
178
+ for item in data_path.split(','):
179
+ try:
180
+ list_data_dict += jload(item)
181
+
182
+ except BaseException:
183
+ with open(item, 'r') as f:
184
+ lines = f.readlines()
185
+ list_data_dict += [json.loads(line.strip()) for line in lines]
186
+
187
+ list_data_dict = random.sample(list_data_dict, len(list_data_dict))
188
+ list_data_dict = list_data_dict[:data_args.data_length]
189
+
190
+ logging.warning("Formatting inputs...")
191
+
192
+ # list_data_dict = [{'instruction':data['items'][0]['value'], 'input':'', 'output':data['items'][1]['value']} for data in list_data_dict]
193
+
194
+ # pass both placeholder names so every PROMPT_DICT template can be filled ({problem} for statement_form, {statement_text} for the rest)
+ list_data_dict = [{'instruction': PROMPT_DICT[data['task']].format(problem=data[KEY_DICT[data['task']][0]], statement_text=data[KEY_DICT[data['task']][0]]), 'input': '',
195
+ 'output':data[KEY_DICT[data['task']][1]] } for data in list_data_dict]
196
+ print(f"len of {len(list_data_dict)}")
197
+ sources = [example['instruction'] for example in list_data_dict]
198
+ targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
199
+ # targets = [example['output'] for example in list_data_dict]
200
+
201
+ self.sources = sources
202
+ self.targets = targets
203
+
204
+ def __len__(self):
205
+ return len(self.sources)
206
+
207
+ def naive__getitem__(self, i) -> Dict[str, torch.Tensor]:
208
+ return dict(input_ids=self.input_ids[i], labels=self.labels[i])
209
+
210
+ def __getitem__(self, i):
211
+ return dict(input_ids=self.sources[i], labels=self.targets[i])
212
+
213
+
214
+ @dataclass
215
+ class DataCollatorForSupervisedDataset(object):
216
+ """Collate examples for supervised fine-tuning."""
217
+
218
+ tokenizer: transformers.PreTrainedTokenizer
219
+
220
+ def naive__call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
221
+ input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
222
+ input_ids = torch.nn.utils.rnn.pad_sequence(
223
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
224
+ )
225
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
226
+ return dict(
227
+ input_ids=input_ids,
228
+ labels=labels,
229
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
230
+ )
231
+
232
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
233
+ sources = []
234
+ targets = []
235
+ for instance in instances:
236
+ source = instance['input_ids']
237
+ target = instance['labels']
238
+ sources.append(source)
239
+ targets.append(target)
240
+
241
+ data_dict = preprocess(sources, targets, self.tokenizer)
242
+ input_ids, labels = data_dict['input_ids'], data_dict['labels']
243
+ # input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
244
+ input_ids = torch.nn.utils.rnn.pad_sequence(
245
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
246
+ )
247
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
248
+ return dict(
249
+ input_ids=input_ids,
250
+ labels=labels,
251
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
252
+ )
253
+
254
+
255
+ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
256
+ """Make dataset and collator for supervised fine-tuning."""
257
+ train_dataset = SupervisedDataset(tokenizer=tokenizer, data_args=data_args)
258
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
259
+ return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
260
+
261
+
262
+ os.environ["WANDB_PROJECT"] = "train_in_one_model"
263
+
264
+
265
+ def train():
266
+ parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
267
+ model_args, data_args, training_args, remaining_args = parser.parse_args_into_dataclasses(
268
+ return_remaining_strings=True)
269
+ data_args.data_length = int(remaining_args[1])
270
+
271
+ model = transformers.AutoModelForCausalLM.from_pretrained(
272
+ model_args.model_name_or_path,
273
+ cache_dir=training_args.cache_dir,
274
+ trust_remote_code=True,
275
+ torch_dtype=torch.bfloat16,
276
+ attn_implementation="flash_attention_2",
277
+ )
278
+
279
+ model.config.use_cache = False
280
+ model.gradient_checkpointing_enable()
281
+
282
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
283
+ model_args.model_name_or_path,
284
+ cache_dir=training_args.cache_dir,
285
+ model_max_length=training_args.model_max_length,
286
+ padding_side="right",
287
+ use_fast=False,
288
+ )
289
+ if tokenizer.pad_token is None:
290
+ smart_tokenizer_and_embedding_resize(
291
+ special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
292
+ tokenizer=tokenizer,
293
+ model=model,
294
+ )
295
+ if "llama" in model_args.model_name_or_path:
296
+ tokenizer.add_special_tokens(
297
+ {
298
+ "eos_token": DEFAULT_EOS_TOKEN,
299
+ "bos_token": DEFAULT_BOS_TOKEN,
300
+ "unk_token": DEFAULT_UNK_TOKEN,
301
+ }
302
+ )
303
+ try:
304
+ tokenizer.pad_token = tokenizer.unk_token
305
+ except:
306
+ pass
307
+
308
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
309
+ trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
310
+ trainer.train()
311
+ model.config.use_cache = True
312
+ # trainer.save_state()
313
+ # if os.environ.get('LOCAL_RANK') == '0':
314
+ safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
315
+
316
+
317
+ if __name__ == "__main__":
318
+ train()
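The key preprocessing step above is that `preprocess` tokenizes prompt and target together and then overwrites the prompt positions in the labels with IGNORE_INDEX, so only the target tokens contribute to the loss. A minimal, tokenizer-free sketch of that masking (the token ids are invented):

IGNORE_INDEX = -100

def mask_labels(source_len, input_ids):
    # copy the inputs into labels and hide the prompt part from the loss
    labels = list(input_ids)
    labels[:source_len] = [IGNORE_INDEX] * source_len
    return labels

input_ids = [11, 12, 13, 14, 201, 202, 2]   # 4 prompt tokens + 3 target tokens
print(mask_labels(source_len=4, input_ids=input_ids))
# [-100, -100, -100, -100, 201, 202, 2]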
repl/.lake/packages/mathlib/scripts/align-import.py ADDED
@@ -0,0 +1,61 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # This script was written by ChatGPT.
4
+ # https://chat.openai.com/share/e0363ebf-ed6f-4fd8-9b76-ebf422ed9f62
5
+
6
+ import re
7
+ import sys
8
+
9
+ def update_file_header(file_path):
10
+ with open(file_path, 'r') as f:
11
+ lines = f.readlines()
12
+
13
+ # Initialize variables
14
+ end_of_header_index = 0
15
+ source_module = ""
16
+ repo_ref = ""
17
+ commit_id = ""
18
+
19
+ # Lines to delete
20
+ delete_indices = []
21
+
22
+ for i, line in enumerate(lines):
23
+ # Check for the end of the "import" lines
24
+ if line.startswith('import'):
25
+ end_of_header_index = i
26
+ elif end_of_header_index != 0 and not line.startswith('import'):
27
+ break
28
+
29
+ # Extract the necessary info for the align import line and mark lines for deletion
30
+ if line.startswith('! This file was ported from'):
31
+ source_module = line.split()[-1]
32
+ delete_indices.append(i)
33
+ elif line.startswith('!') and 'commit' in line and commit_id == "":
34
+ split_line = line.split()
35
+ repo_ref = split_line[1]
36
+ commit_id = split_line[-1]
37
+ delete_indices.append(i)
38
+ elif line.startswith('!'):
39
+ delete_indices.append(i)
40
+ elif line == "\n" and lines[i+1].startswith("!"):
41
+ delete_indices.append(i)
42
+
43
+ # Only proceed if we have found the necessary info for the align import line
44
+ if source_module and repo_ref and commit_id:
45
+ # Generate the new line
46
+ new_line = f'\n#align_import {source_module} from "{repo_ref}"@"{commit_id}"\n'
47
+
48
+ # Delete the marked lines
49
+ for index in sorted(delete_indices, reverse=True):
50
+ del lines[index]
51
+
52
+ # Insert the new line after the "import" lines
53
+ lines.insert(end_of_header_index - len(delete_indices) + 1, new_line)
54
+
55
+ # Write the updated lines back to the file
56
+ with open(file_path, 'w') as f:
57
+ f.writelines(lines)
58
+
59
+ # The first command line argument is the file path
60
+ file_path = sys.argv[1]
61
+ update_file_header(file_path)
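In short, the script deletes the old "! This file was ported from ..." comment block and inserts a single #align_import line after the imports. A small illustration of the line it generates; the module name, repo and commit hash here are made up:

source_module = "algebra.group.basic"
repo_ref = "leanprover-community/mathlib"
commit_id = "0123456789abcdef0123456789abcdef01234567"
print(f'#align_import {source_module} from "{repo_ref}"@"{commit_id}"')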
repl/.lake/packages/mathlib/scripts/align.py ADDED
@@ -0,0 +1,64 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Tool to add source headers to ported theory files,
4
+ # archived for historical purposes.
5
+
6
+ from pathlib import Path
7
+ import re
8
+ import yaml
9
+
10
+ excepts = {
11
+ 'categorytheory.category.rel': 'categorytheory.category.relcat',
12
+ 'categorytheory.isomorphism': 'categorytheory.iso',
13
+ 'categorytheory.naturalisomorphism': 'categorytheory.natiso',
14
+ 'categorytheory.naturaltransformation': 'categorytheory.nattrans',
15
+ 'leancore.data.vector': 'data.vector',
16
+ 'order.monovary': 'order.monotone.monovary'
17
+ }
18
+
19
+ def condense(s):
20
+ if s.startswith('Mathlib/'):
21
+ s = s[len('Mathlib/'):]
22
+ if s.endswith('.lean'):
23
+ s = s[:-5]
24
+ s = s.lower()
25
+ s = s.replace('/', '.')
26
+ s = s.replace('_', '')
27
+ if s in excepts:
28
+ s = excepts[s]
29
+ return s
30
+
31
+ port_status = yaml.safe_load(open("mathlib4-port-status.yaml").read())
32
+
33
+ # map from condensed names to mathlib4 paths
34
+ map = {}
35
+ for path in Path('Mathlib').glob('**/*.lean'):
36
+ path = str(path)
37
+ map[condense(path)] = path
38
+
39
+ count = 0
40
+ for key, val in port_status.items():
41
+ if val.startswith('Yes'):
42
+ sha = val.split()[2]
43
+ mathlib3 = key
44
+ mathlib4 = map[condense(key)]
45
+
46
+ place = '(\n-/\n\n?import )'
47
+ blob = "\n\n! This file was ported from Lean 3 source module " + mathlib3 + "\n" + \
48
+ "! leanprover-community/mathlib commit " + sha + "\n" + \
49
+ "! Please do not edit these lines, except to modify the commit id\n" + \
50
+ "! if you have ported upstream changes."
51
+ old = open(mathlib4).read()
52
+
53
+ if blob[1:] in old: # match even without leading newline
54
+ print(f'{mathlib4} already has header')
55
+ elif "! leanprover-community/mathlib commit " in old:
56
+ m = re.search("^! leanprover-community/mathlib commit (.*)$", old, flags=re.MULTILINE)
57
+ print(f'file says {m.groups()[0]} but we want {sha}')
58
+ assert(False)
59
+ else:
60
+ new = re.sub(place, blob + '\\1', old, flags=re.MULTILINE)
61
+ open(mathlib4, 'w').write(new)
62
+ count += 1
63
+
64
+ print(count)
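The `condense` helper is what lets the script match mathlib3 keys from the port-status file against Mathlib4 paths. A self-contained illustration of the normalization (the `excepts` special cases are omitted here):

def condense(s):
    if s.startswith('Mathlib/'):
        s = s[len('Mathlib/'):]
    if s.endswith('.lean'):
        s = s[:-5]
    return s.lower().replace('/', '.').replace('_', '')

print(condense('Mathlib/Order/SymmDiff.lean'))  # order.symmdiff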
repl/.lake/packages/mathlib/scripts/bench/accumulate_profile.py ADDED
@@ -0,0 +1,21 @@
1
+ #!/usr/bin/env python3
2
+ # sum up times of lines a la `elaboration 100ms`
3
+
4
+ import collections
5
+ import re
6
+ import sys
7
+
8
+ cats = collections.defaultdict(lambda: 0)
9
+ for line in sys.stdin:
10
+ sys.stderr.write(line)
11
+ if m := re.match("(.+?) ([\d.]+)(m?)s$", line):
12
+ cats[m[1].strip()] += float(m[2]) * (1e-3 if m[3] else 1)
13
+
14
+ for cat in sorted(cats.keys()):
15
+ cat2 = cat
16
+ if len(sys.argv) > 1:
17
+ cat2 = f"{sys.argv[1]} {cat}"
18
+ # default unit to `s`
19
+ if "|" not in cat2:
20
+ cat2 += "|s"
21
+ print(f"{cat2!r}: {cats[cat]:f}")
repl/.lake/packages/mathlib/scripts/detect_sha_changes.py ADDED
@@ -0,0 +1,104 @@
1
+ """
2
+ This script is called by a github action to verify that commit SHAs used in porting are valid.
3
+ It also produces links to the port-status webpage.
4
+
5
+ Note that only the first 10 annotations created with this action are guaranteed to appear, so we
6
+ produce the errors first.
7
+ """
8
+
9
+ import dataclasses
10
+ import re
11
+ import sys
12
+ from typing import Optional
13
+
14
+ import git
15
+
16
+ # upstream bug
17
+ git.Git.CatFileContentStream.__next__ = git.Git.CatFileContentStream.next
18
+
19
+ align_import_re = re.compile(
20
+ r'^#align_import ([^ ]*) from "(leanprover-community/[a-z]*)" ?@ ?"([0-9a-f]*)"')
21
+
22
+ @dataclasses.dataclass(eq=True, frozen=True)
23
+ class VersionInfo:
24
+ module: str
25
+ repo: Optional[str]
26
+ commit: Optional[str]
27
+ commit_line_no: Optional[int] = dataclasses.field(compare=False)
28
+
29
+ def to_commit(self):
30
+ try:
31
+ repo = git.Repo('port-repos/' + self.repo)
32
+ except git.exc.NoSuchPathError:
33
+ raise ValueError(f"Repo {self.repo} not recognized")
34
+ try:
35
+ repo.remotes.origin.fetch(self.commit, depth=1)
36
+ except Exception:
37
+ pass
38
+ return repo.commit(self.commit)
39
+
40
+ def get_mathlib4_module_commit_infos(contents):
41
+ for i, line in enumerate(contents, 1):
42
+ m = align_import_re.match(line)
43
+ if m:
44
+ module = m.group(1)
45
+ repo = m.group(2)
46
+ commit = m.group(3)
47
+ yield VersionInfo(module, repo, commit, i)
48
+
49
+ def get_mathlib4_module_commit_info_from_blob(blob: Optional[git.Blob]):
50
+ if blob is None:
51
+ return
52
+ yield from get_mathlib4_module_commit_infos(
53
+ l.decode('utf8') for l in blob.data_stream.stream)
54
+
55
+ def encode_msg_text_for_github(msg):
56
+ # even though this is probably url quoting, we match the implementation at
57
+ # https://github.com/actions/toolkit/blob/af821474235d3c5e1f49cee7c6cf636abb0874c4/packages/core/src/command.ts#L36-L94
58
+ return msg.replace('%', '%25').replace('\r', '%0D').replace('\n', '%0A')
59
+
60
+ if __name__ == '__main__':
61
+ repo = git.Repo('.')
62
+ base = repo.commit(sys.argv[1])
63
+ head = repo.commit(sys.argv[2])
64
+ any_errors = False
65
+
66
+ diff_infos = []
67
+ for diff in base.diff(head, paths=['Mathlib']):
68
+ a_info = set(get_mathlib4_module_commit_info_from_blob(diff.a_blob))
69
+ b_info = set(get_mathlib4_module_commit_info_from_blob(diff.b_blob))
70
+ if b_info <= a_info:
71
+ continue
72
+ diff_infos.append((diff, a_info, b_info))
73
+
74
+ all_refs = {}
75
+
76
+ # produce errors first
77
+ for diff, a_infos, b_infos in diff_infos:
78
+ for b_info in b_infos:
79
+ try:
80
+ b_info.to_commit()
81
+ except Exception as e:
82
+ print(f"::error file={diff.b_blob.path},line={b_info.commit_line_no},title=Invalid header::{encode_msg_text_for_github(str(e))}")
83
+ any_errors = True
84
+ continue
85
+
86
+ for diff, a_info, b_info in diff_infos:
87
+ same = a_info.intersection(b_info)
88
+ a_info -= same
89
+ b_info -= same
90
+ if a_info and b_info:
91
+ a_info_by_mod = {a.module: a for a in a_info}
92
+ b_info_by_mod = {b.module: b for b in b_info}
93
+ for k in set(a_info_by_mod.keys()) | set(b_info_by_mod.keys()):
94
+ a_info = a_info_by_mod.get(k, None)
95
+ b_info = b_info_by_mod.get(k, None)
96
+ if a_info is None or b_info is None:
97
+ pass
98
+ elif a_info.module == b_info.module:
99
+ mod_path = a_info.module.replace('.', '/')
100
+ msg = f"See review instructions and diff at\nhttps://leanprover-community.github.io/mathlib-port-status/file/{mod_path}?range={a_info.commit}..{b_info.commit}"
101
+ print(f"::notice file={diff.b_blob.path},line={b_info.commit_line_no},title=Synchronization::{encode_msg_text_for_github(msg)}")
102
+
103
+ if any_errors:
104
+ raise SystemExit("Setting a failure due to errors above")
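For clarity, this is the header format the `align_import_re` pattern is looking for; the module name and commit hash below are made up:

import re

align_import_re = re.compile(
    r'^#align_import ([^ ]*) from "(leanprover-community/[a-z]*)" ?@ ?"([0-9a-f]*)"')
line = '#align_import algebra.group.basic from "leanprover-community/mathlib"@"e4bc4a1ce0a50b5d8a3e0b1c2d3e4f5a6b7c8d9e"'
m = align_import_re.match(line)
print(m.group(1), m.group(2), m.group(3))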
repl/.lake/packages/mathlib/scripts/fix-comments.py ADDED
@@ -0,0 +1,162 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import sys
4
+ import os
5
+ from pathlib import Path
6
+ import subprocess
7
+ import re
8
+
9
+ if len(sys.argv) != 2 or not sys.argv[1].endswith('.lean'):
10
+ print("usage: fix-comments.py X.lean")
11
+ sys.exit(1)
12
+
13
+ leanfile = sys.argv[1]
14
+
15
+ is_clean = subprocess.run(
16
+ ['git', 'status', '--untracked-files=no', '--porcelain'],
17
+ capture_output=True,
18
+ check=True,
19
+ encoding='utf-8').stdout.rstrip()
20
+
21
+ if is_clean != "":
22
+ print("Certain files tracked by git have uncommitted changes.\n")
23
+ os.system("git status --untracked-files=no")
24
+ print("\n")
25
+ s = input("Type y to continue. ")
26
+ if s != 'y':
27
+ sys.exit(1)
28
+
29
+ root_dir = subprocess.run(
30
+ ['git', 'rev-parse', '--show-toplevel'],
31
+ capture_output=True,
32
+ check=True,
33
+ encoding='utf-8').stdout.rstrip()
34
+
35
+ align_files = subprocess.run(
36
+ ['git', 'grep', '-l', '^#align'],
37
+ cwd=root_dir,
38
+ capture_output=True,
39
+ check=True,
40
+ encoding='utf-8')
41
+
42
+ name_map = dict()
43
+ for f in align_files.stdout.splitlines():
44
+ with open(os.path.join(root_dir, f), encoding="utf-8") as fh:
45
+ contents = fh.read()
46
+ for p in contents.split(sep='\n#align')[1:]:
47
+ n3, n4, *_ = p.split(maxsplit=2)
48
+ name_map[n3] = n4
49
+
50
+ def replace_names(s):
51
+ # Terrible hack to treat `.` as a word character
52
+ # (to match qualified names)
53
+ s = s.replace('.', 'Ᾰ')
54
+ # re.DOTALL means that `.` can also match a newline.
55
+ # `\A` and `\Z` match only at the start/end of the string respectively.
56
+ w = re.findall(r'(?:\b|\A).+?(?:\b|\Z)', s, flags=re.DOTALL)
57
+ for i in range(len(w)):
58
+ w[i] = w[i].replace('Ᾰ', '.')
59
+ w[i] = name_map.get(w[i], w[i])
60
+ return ''.join(w)
61
+
62
+ def process_backticked_names(s):
63
+ w = s.split(sep='`')
64
+ for i in range(len(w)):
65
+ if i % 2 == 1:
66
+ w[i] = replace_names(w[i])
67
+ return '`'.join(w)
68
+
69
+ rewritten_contents = ''
70
+
71
+ in_block_comment = False
72
+ in_line_comment = False
73
+ prev_char = None
74
+ comment_so_far = None # contains end marker but not begin marker
75
+
76
+ def finish_comment():
77
+ global rewritten_contents
78
+ global in_block_comment
79
+ global in_line_comment
80
+ global comment_so_far
81
+ if comment_so_far is not None:
82
+ rewritten_contents += process_backticked_names(comment_so_far)
83
+ in_block_comment = False
84
+ in_line_comment = False
85
+ comment_so_far = None
86
+
87
+ with open(leanfile, encoding="utf-8") as F:
88
+ while 1:
89
+ char = F.read(1)
90
+ if not char:
91
+ finish_comment()
92
+ break
93
+
94
+ if in_block_comment or in_line_comment:
95
+ comment_so_far = comment_so_far + char
96
+ else:
97
+ rewritten_contents += char
98
+
99
+ if in_block_comment and prev_char == '-' and char == '/':
100
+ finish_comment()
101
+
102
+ if in_line_comment and char == '\n':
103
+ finish_comment()
104
+
105
+ if comment_so_far is None and prev_char == '/' and char == '-':
106
+ in_block_comment = True
107
+ comment_so_far = ''
108
+
109
+ if comment_so_far is None and prev_char == '-' and char == '-':
110
+ in_line_comment = True
111
+ comment_so_far = ''
112
+
113
+ prev_char = char
114
+
115
+ def mktree(path, sha, tree=True):
116
+ if path == Path('.'):
117
+ return sha
118
+ if tree:
119
+ inp = f"040000 tree {sha}\t{path.name}"
120
+ else:
121
+ inp = f"100644 blob {sha}\t{path.name}"
122
+ tree_sha = subprocess.run(
123
+ ['git', 'mktree'],
124
+ cwd=root_dir,
125
+ input=inp,
126
+ capture_output=True,
127
+ check=True,
128
+ encoding='utf8').stdout.rstrip()
129
+ return mktree(path.parent, tree_sha)
130
+
131
+ path = Path(subprocess.run(
132
+ ['git', 'ls-files', '--full-name', leanfile],
133
+ capture_output=True,
134
+ check=True,
135
+ encoding='utf-8').stdout.rstrip())
136
+
137
+ blob_sha = subprocess.run(
138
+ ['git', 'hash-object', '-w', '--stdin'],
139
+ input=rewritten_contents,
140
+ cwd=root_dir,
141
+ capture_output=True,
142
+ check=True,
143
+ encoding='utf-8').stdout.rstrip()
144
+
145
+ tree_sha = mktree(path, blob_sha, tree=False)
146
+
147
+ print(f"The script will now interactively suggest changes to {leanfile}.\n")
148
+ s = input("Type y to continue. ")
149
+ if s != 'y':
150
+ sys.exit(1)
151
+
152
+ subprocess.run(['git', 'restore', '--patch', '--source=' + tree_sha, '--', leanfile], check=True)
153
+
154
+ r = subprocess.run(['git', 'diff', '--quiet', leanfile])
155
+ if r.returncode == 0:
156
+ pass
157
+ elif r.returncode == 1: # file was changed
158
+ print("\nPerhaps you would now like to run:")
159
+ print(f"git add {leanfile} && git commit -m 'auto: naming'")
160
+ else:
161
+ # something went wrong
162
+ r.check_returncode()
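The core renaming idea above is that only text between backticks is rewritten through the #align name map. A simplified, self-contained sketch (toy one-entry map, whole-segment lookup instead of the word-by-word replacement used above):

name_map = {"nat.succ_le_iff": "Nat.succ_le_iff"}

def process_backticked_names(s):
    parts = s.split('`')
    for i in range(1, len(parts), 2):   # odd indices were between backticks
        parts[i] = name_map.get(parts[i], parts[i])
    return '`'.join(parts)

print(process_backticked_names("See `nat.succ_le_iff` for details."))
# See `Nat.succ_le_iff` for details.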
repl/.lake/packages/mathlib/scripts/fix-line-breaks.py ADDED
@@ -0,0 +1,23 @@
1
+ #!/usr/bin/env python3
2
+ import sys
3
+ from collections import deque
4
+
5
+ lns = deque([], 2)
6
+ with open(sys.argv[1], "r", encoding="utf-8", newline="\n") as f, \
7
+ open(sys.argv[2], "w", encoding="utf-8", newline="\n") as g:
8
+ for ln_raw in f:
9
+ ln = ln_raw.strip("\n")
10
+ lns.append(ln)
11
+ if len(lns) <= 1:
12
+ continue
13
+ if lns[1].lstrip() == "by" and len(lns[0]) < 98 and not lns[0].lstrip().startswith("--"):
14
+ lns.pop()
15
+ lns[0] += " by"
16
+ elif lns[1].lstrip() == "where" and len(lns[0]) < 95 and not lns[0].lstrip().startswith("--"):
17
+ lns.pop()
18
+ lns[0] += " where"
19
+ else:
20
+ print(lns[0], file=g)
21
+ lns.popleft()
22
+ for ln in lns:
23
+ print(ln, file=g)
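The effect of this script is easiest to see on a dangling `by`: if the previous line is short enough and not a comment, the `by` is pulled up onto it. A minimal sketch of that rule on an invented snippet:

lines = ["theorem foo : 1 + 1 = 2 :=", "  by", "  rfl"]
out = []
for ln in lines:
    if ln.lstrip() == "by" and out and len(out[-1]) < 98 and not out[-1].lstrip().startswith("--"):
        out[-1] += " by"
    else:
        out.append(ln)
print(out)  # ['theorem foo : 1 + 1 = 2 := by', '  rfl']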
repl/.lake/packages/mathlib/scripts/fix-lints.py ADDED
@@ -0,0 +1,48 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import sys
4
+ import os
5
+ import subprocess
6
+ import shutil
7
+
8
+ def getpos(line):
9
+ _, line, col, _ = line.split(sep=':', maxsplit=3)
10
+ return int(line), int(col)
11
+
12
+ if len(sys.argv) != 2 or not sys.argv[1].endswith('.lean'):
13
+ print("usage: fix-lints.py Mathlib/A/B/C.lean")
14
+ sys.exit(1)
15
+
16
+ leanfile = sys.argv[1]
17
+ leanmodule = leanfile[:-5].replace('/', '.')
18
+
19
+ # try to build
20
+ log = subprocess.run(
21
+ ['lake', 'build', leanmodule],
22
+ capture_output=True, encoding='utf8')
23
+ if log.returncode == 0:
24
+ print("no errors 🎉")
25
+ exit(0)
26
+
27
+ shutil.copyfile(leanfile, leanfile + '.bak')
28
+
29
+ with open(leanfile + '.bak', encoding='utf8') as fp:
30
+ f = list(fp)
31
+ count = 0
32
+ for l in reversed(log.stderr.splitlines()):
33
+ if 'linter.unusedVariables' in l:
34
+ line, col = getpos(l)
35
+ f[line-1] = f[line-1][0:col] + '_' + f[line-1][col:]
36
+ count += 1
37
+ elif 'linter.unnecessarySeqFocus' in l:
38
+ line, col = getpos(l)
39
+ f[line-1] = f[line-1][0:col].rstrip() + ';' + f[line-1][col+3:]
40
+ count += 1
41
+ else:
42
+ print(l, file=sys.stderr)
43
+
44
+ print(f'Fixed {count} warnings', file=sys.stderr)
45
+
46
+ with open(leanfile, 'w', encoding='utf8') as fp:
47
+ fp.write(''.join(f))
48
+ os.remove(leanfile + '.bak')
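The whole script hinges on `getpos` pulling the line and column out of lake build lint messages. An illustration with an invented warning line (the exact message format may differ between toolchain versions):

def getpos(line):
    _, line_no, col, _ = line.split(sep=':', maxsplit=3)
    return int(line_no), int(col)

warning = "Mathlib/Algebra/Group/Basic.lean:42:7: warning: unused variable `h` [linter.unusedVariables]"
print(getpos(warning))  # (42, 7)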
repl/.lake/packages/mathlib/scripts/lint-style.py ADDED
@@ -0,0 +1,451 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Lint a file or files from mathlib for style.
4
+
5
+ Sample usage:
6
+
7
+ $ ./scripts/lint-style.py $(find Mathlib -name '*.lean')
8
+
9
+ which will lint all of the Lean files in the specified directories.
10
+
11
+ The resulting error output will contain one line for each style error
12
+ encountered that isn't in the list of allowed / ignored style exceptions.
13
+
14
+ Paths with no errors will not appear in the output, and the script will
15
+ exit with successful return code if there are no errors encountered in
16
+ any provided paths.
17
+
18
+ Paths emitted in the output will match the paths provided on the
19
+ command line for any files containing errors -- in particular, linting
20
+ a relative path (like ``Mathlib/Foo/Bar.lean``) will produce errors
21
+ that contain the relative path, whilst linting absolute paths (like
22
+ ``/root/mathlib4/Mathlib/Foo/Bar.lean``) will produce errors with the
23
+ absolute path.
24
+
25
+ This script can also be used to regenerate the list of allowed / ignored style
26
+ exceptions by redirecting the output to ``style-exceptions.txt``. Use:
27
+
28
+ $ ./scripts/update-style-exceptions.sh
29
+
30
+ to perform this update.
31
+ """
32
+
33
+ # TODO: This is adapted from the linter for mathlib3. It should be rewritten in Lean.
34
+
35
+ from pathlib import Path
36
+ import sys
37
+ import re
38
+ import shutil
39
+
40
+ ERR_COP = 0 # copyright header
41
+ ERR_MOD = 2 # module docstring
42
+ ERR_LIN = 3 # line length
43
+ ERR_OPT = 6 # set_option
44
+ ERR_AUT = 7 # malformed authors list
45
+ ERR_TAC = 9 # imported Mathlib.Tactic
46
+ ERR_IBY = 11 # isolated by
47
+ ERR_DOT = 12 # isolated or low focusing dot
48
+ ERR_SEM = 13 # the substring " ;"
49
+ ERR_WIN = 14 # Windows line endings "\r\n"
50
+ ERR_TWS = 15 # trailing whitespace
51
+ ERR_CLN = 16 # line starts with a colon
52
+ ERR_IND = 17 # second line not correctly indented
53
+ ERR_ARR = 18 # space after "←"
54
+ ERR_NUM_LIN = 19 # file is too large
55
+ ERR_NSP = 20 # non-terminal simp
56
+
57
+ exceptions = []
58
+
59
+ SCRIPTS_DIR = Path(__file__).parent.resolve()
60
+ ROOT_DIR = SCRIPTS_DIR.parent
61
+
62
+
63
+ with SCRIPTS_DIR.joinpath("style-exceptions.txt").open(encoding="utf-8") as f:
64
+ for exline in f:
65
+ filename, _, _, _, _, errno, *extra = exline.split()
66
+ path = ROOT_DIR / filename
67
+ if errno == "ERR_COP":
68
+ exceptions += [(ERR_COP, path, None)]
69
+ elif errno == "ERR_MOD":
70
+ exceptions += [(ERR_MOD, path, None)]
71
+ elif errno == "ERR_LIN":
72
+ exceptions += [(ERR_LIN, path, None)]
73
+ elif errno == "ERR_OPT":
74
+ exceptions += [(ERR_OPT, path, None)]
75
+ elif errno == "ERR_AUT":
76
+ exceptions += [(ERR_AUT, path, None)]
77
+ elif errno == "ERR_TAC":
78
+ exceptions += [(ERR_TAC, path, None)]
79
+ elif errno == "ERR_NUM_LIN":
80
+ exceptions += [(ERR_NUM_LIN, path, extra[1])]
81
+ else:
82
+ print(f"Error: unexpected errno in style-exceptions.txt: {errno}")
83
+ sys.exit(1)
84
+
85
+ new_exceptions = False
86
+
87
+ def annotate_comments(enumerate_lines):
88
+ """
89
+ Take a list of tuples of enumerated lines of the form
90
+ (line_number, line, ...)
91
+ and return a list of
92
+ (line_number, line, ..., True/False)
93
+ where lines have True attached when they are in comments.
94
+ """
95
+ nesting_depth = 0 # We're in a comment when `nesting_depth > 0`.
96
+ starts_in_comment = False # Whether we're in a comment when starting the line.
97
+ for line_nr, line, *rem in enumerate_lines:
98
+ # We assume multiline comments do not begin or end within single-line comments.
99
+ if line == "\n" or line.lstrip().startswith("--"):
100
+ yield line_nr, line, *rem, True
101
+ continue
102
+ # We assume that "/-/" and "-/-" never occur outside of "--" comments.
103
+ # We assume that we do not encounter "... -/ <term> /- ...".
104
+ # We also don't account for "/-" and "-/" appearing in strings.
105
+ starts_in_comment = (nesting_depth > 0)
106
+ nesting_depth = nesting_depth + line.count("/-") - line.count("-/")
107
+ in_comment = (starts_in_comment or line.lstrip().startswith("/-")) and \
108
+ (nesting_depth > 0 or line.rstrip().endswith("-/"))
109
+ yield line_nr, line, *rem, in_comment
110
+
111
+ def annotate_strings(enumerate_lines):
112
+ """
113
+ Take a list of tuples of enumerated lines of the form
114
+ (line_number, line, ...)
115
+ and return a list of
116
+ (line_number, line, ..., True/False)
117
+ where lines have True attached when they are in strings.
118
+ """
119
+ in_string = False
120
+ in_comment = False
121
+ for line_nr, line, *rem in enumerate_lines:
122
+ # ignore comment markers inside string literals
123
+ if not in_string:
124
+ if "/-" in line:
125
+ in_comment = True
126
+ if "-/" in line:
127
+ in_comment = False
128
+ # ignore quotes inside comments
129
+ if not in_comment:
130
+ # crude heuristic: if the number of non-escaped quote signs is odd,
131
+ # we're starting / ending a string literal
132
+ if line.count("\"") - line.count("\\\"") % 2 == 1:
133
+ in_string = not in_string
134
+ # if there are quote signs in this line,
135
+ # a string literal probably begins and / or ends here,
136
+ # so we skip this line
137
+ if line.count("\"") > 0:
138
+ yield line_nr, line, *rem, True
139
+ continue
140
+ if in_string:
141
+ yield line_nr, line, *rem, True
142
+ continue
143
+ yield line_nr, line, *rem, False
144
+
145
+ def set_option_check(lines, path):
146
+ errors = []
147
+ newlines = []
148
+ for line_nr, line, in_comment, in_string in annotate_strings(annotate_comments(lines)):
149
+ if line.strip().startswith('set_option') and not in_comment and not in_string:
150
+ option_prefix = line.strip().split(' ', 2)[1].split('.', 1)[0]
151
+ # forbidden options: pp, profiler, trace
152
+ if option_prefix in {'pp', 'profiler', 'trace'}:
153
+ errors += [(ERR_OPT, line_nr, path)]
154
+ # skip adding this line to newlines so that we suggest removal
155
+ continue
156
+ newlines.append((line_nr, line))
157
+ return errors, newlines
158
+
159
+ def line_endings_check(lines, path):
160
+ errors = []
161
+ newlines = []
162
+ for line_nr, line in lines:
163
+ if "\r\n" in line:
164
+ errors += [(ERR_WIN, line_nr, path)]
165
+ line = line.replace("\r\n", "\n")
166
+ if line.endswith(" \n"):
167
+ errors += [(ERR_TWS, line_nr, path)]
168
+ line = line.rstrip() + "\n"
169
+ newlines.append((line_nr, line))
170
+ return errors, newlines
171
+
172
+ def four_spaces_in_second_line(lines, path):
173
+ # TODO: also fix the space for all lines before ":=", right now we only fix the line after
174
+ # the first line break
175
+ errors = []
176
+ # We never alter the first line, as it does not occur as next_line in the iteration over the
177
+ # zipped lines below, hence we add it here
178
+ newlines = [lines[0]]
179
+ annotated_lines = list(annotate_comments(lines))
180
+ for (_, line, is_comment), (next_line_nr, next_line, _) in zip(annotated_lines,
181
+ annotated_lines[1:]):
182
+ # Check if the current line matches "(lemma|theorem) .* :"
183
+ new_next_line = next_line
184
+ if (not is_comment) and re.search(r"^(protected )?(def|lemma|theorem) (?!.*:=).*(where)?$",
185
+ line):
186
+ # Calculate the number of spaces before the first non-space character in the next line
187
+ stripped_next_line = next_line.lstrip()
188
+ if not (next_line == '\n' or next_line.startswith("#") or stripped_next_line.startswith("--")):
189
+ num_spaces = len(next_line) - len(stripped_next_line)
190
+ # The match with "| " could potentially match with a different usage of the same
191
+ # symbol, e.g. some sort of norm. In that case a space is not necessary, so
192
+ # looking for "| " should be enough.
193
+ if stripped_next_line.startswith("| ") or line.endswith("where\n"):
194
+ # Check and fix if the number of leading space is not 2
195
+ if num_spaces != 2:
196
+ errors += [(ERR_IND, next_line_nr, path)]
197
+ new_next_line = ' ' * 2 + stripped_next_line
198
+ # Check and fix if the number of leading spaces is not 4
199
+ else:
200
+ if num_spaces != 4:
201
+ errors += [(ERR_IND, next_line_nr, path)]
202
+ new_next_line = ' ' * 4 + stripped_next_line
203
+ newlines.append((next_line_nr, new_next_line))
204
+ return errors, newlines
205
+
206
+ def nonterminal_simp_check(lines, path):
207
+ errors = []
208
+ newlines = []
209
+ annotated_lines = list(annotate_comments(lines))
210
+ for (line_nr, line, is_comment), (_, next_line, _) in zip(annotated_lines,
211
+ annotated_lines[1:]):
212
+ # Check if the current line matches whitespace followed by "simp"
213
+ new_line = line
214
+ # TODO it would be better to use a regex like r"^\s*simp( \[.*\])?( at .*)?$" and thereby
215
+ # catch all possible simp invocations. Adding this will require more initial cleanup or
216
+ # nolint.
217
+ if (not is_comment) and re.search(r"^\s*simp$", line):
218
+ # Calculate the number of spaces before the first non-space character in the line
219
+ num_spaces = len(line) - len(line.lstrip())
220
+ # Calculate the number of spaces before the first non-space character in the next line
221
+ stripped_next_line = next_line.lstrip()
222
+ if not (next_line == '\n' or next_line.startswith("#") or stripped_next_line.startswith("--") or "rfl" in next_line):
223
+ num_next_spaces = len(next_line) - len(stripped_next_line)
224
+ # Check if the number of leading spaces is the same
225
+ if num_spaces == num_next_spaces:
226
+ # If so, the simp is nonterminal
227
+ errors += [(ERR_NSP, line_nr, path)]
228
+ new_line = line.replace("simp", "simp?")
229
+ newlines.append((line_nr, new_line))
230
+ newlines.append(lines[-1])
231
+ return errors, newlines
232
+
233
+ def long_lines_check(lines, path):
234
+ errors = []
235
+ # TODO: find a good way to break long lines
236
+ # TODO: some string literals (in e.g. tactic output messages) can be excepted from this rule
237
+ for line_nr, line in lines:
238
+ if "http" in line or "#align" in line:
239
+ continue
240
+ if len(line) > 101:
241
+ errors += [(ERR_LIN, line_nr, path)]
242
+ return errors, lines
243
+
244
+ def import_only_check(lines, path):
245
+ for _, line, is_comment in annotate_comments(lines):
246
+ if is_comment:
247
+ continue
248
+ imports = line.split()
249
+ if imports[0] == "#align_import":
250
+ continue
251
+ if imports[0] != "import":
252
+ return False
253
+ return True
254
+
255
+ def regular_check(lines, path):
256
+ errors = []
257
+ copy_started = False
258
+ copy_done = False
259
+ copy_start_line_nr = 1
260
+ copy_lines = ""
261
+ for line_nr, line in lines:
262
+ if not copy_started and line == "\n":
263
+ errors += [(ERR_COP, copy_start_line_nr, path)]
264
+ continue
265
+ if not copy_started and line == "/-\n":
266
+ copy_started = True
267
+ copy_start_line_nr = line_nr
268
+ continue
269
+ if not copy_started:
270
+ errors += [(ERR_COP, line_nr, path)]
271
+ if copy_started and not copy_done:
272
+ copy_lines += line
273
+ if "Author" in line:
274
+ # Validating names is not a reasonable thing to do,
275
+ # so we just look for the two common variations:
276
+ # using ' and ' between names, and a '.' at the end of line.
277
+ if ((not line.startswith("Authors: ")) or
278
+ (" " in line) or
279
+ (" and " in line) or
280
+ (line[-2] == '.')):
281
+ errors += [(ERR_AUT, line_nr, path)]
282
+ if line == "-/\n":
283
+ if ((not "Copyright" in copy_lines) or
284
+ (not "Apache" in copy_lines) or
285
+ (not "Authors: " in copy_lines)):
286
+ errors += [(ERR_COP, copy_start_line_nr, path)]
287
+ copy_done = True
288
+ continue
289
+ if copy_done and line == "\n":
290
+ continue
291
+ words = line.split()
292
+ if words[0] != "import" and words[0] != "--" and words[0] != "/-!" and words[0] != "#align_import":
293
+ errors += [(ERR_MOD, line_nr, path)]
294
+ break
295
+ if words[0] == "/-!":
296
+ break
297
+ return errors, lines
298
+
299
+ def banned_import_check(lines, path):
300
+ errors = []
301
+ for line_nr, line, is_comment in annotate_comments(lines):
302
+ if is_comment:
303
+ continue
304
+ imports = line.split()
305
+ if imports[0] != "import":
306
+ break
307
+ if imports[1] in ["Mathlib.Tactic"]:
308
+ errors += [(ERR_TAC, line_nr, path)]
309
+ return errors, lines
310
+
311
+ def isolated_by_dot_semicolon_check(lines, path):
312
+ errors = []
313
+ newlines = []
314
+ for line_nr, line in lines:
315
+ if line.strip() == "by":
316
+ # We excuse those "by"s following a comma or ", fun ... =>", since generally hanging "by"s
317
+ # should not be used in the second or later arguments of a tuple/anonymous constructor
318
+ # See https://github.com/leanprover-community/mathlib4/pull/3825#discussion_r1186702599
319
+ prev_line = lines[line_nr - 2][1].rstrip()
320
+ if not prev_line.endswith(",") and not re.search(", fun [^,]* (=>|↦)$", prev_line):
321
+ errors += [(ERR_IBY, line_nr, path)]
322
+ if line.lstrip().startswith(". "):
323
+ errors += [(ERR_DOT, line_nr, path)]
324
+ line = line.replace(". ", "· ", 1)
325
+ if line.strip() in (".", "·"):
326
+ errors += [(ERR_DOT, line_nr, path)]
327
+ if " ;" in line:
328
+ errors += [(ERR_SEM, line_nr, path)]
329
+ line = line.replace(" ;", ";")
330
+ if line.lstrip().startswith(":"):
331
+ errors += [(ERR_CLN, line_nr, path)]
332
+ newlines.append((line_nr, line))
333
+ return errors, newlines
334
+
335
+ def left_arrow_check(lines, path):
336
+ errors = []
337
+ newlines = []
338
+ for line_nr, line, is_comment, in_string in annotate_strings(annotate_comments(lines)):
339
+ if is_comment or in_string:
340
+ newlines.append((line_nr, line))
341
+ continue
342
+ # Allow "←" to be followed by "%" or "`", but not by "`(" or "``(" (since "`()" and "``()"
343
+ # are used for syntax quotations). Otherwise, insert a space after "←".
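+ # For example, "rw [←foo]" is rewritten to "rw [← foo]", while "←%x" is left unchanged.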
344
+ new_line = re.sub(r'←(?:(?=``?\()|(?![%`]))(\S)', r'← \1', line)
345
+ if new_line != line:
346
+ errors += [(ERR_ARR, line_nr, path)]
347
+ newlines.append((line_nr, new_line))
348
+ return errors, newlines
349
+
350
+ def output_message(path, line_nr, code, msg):
351
+ if len(exceptions) == 0:
352
+ # we are generating a new exceptions file
353
+ # filename first, then line so that we can call "sort" on the output
354
+ print(f"{path} : line {line_nr} : {code} : {msg}")
355
+ else:
356
+ if code.startswith("ERR"):
357
+ msg_type = "error"
358
+ if code.startswith("WRN"):
359
+ msg_type = "warning"
360
+ # We are outputting for github. We duplicate path, line_nr and code,
361
+ # so that they are also visible in the plaintext output.
362
+ print(f"::{msg_type} file={path},line={line_nr},code={code}::{path}#L{line_nr}: {code}: {msg}")
363
+
364
+ def format_errors(errors):
365
+ global new_exceptions
366
+ for errno, line_nr, path in errors:
367
+ if (errno, path.resolve(), None) in exceptions:
368
+ continue
369
+ new_exceptions = True
370
+ if errno == ERR_COP:
371
+ output_message(path, line_nr, "ERR_COP", "Malformed or missing copyright header")
372
+ if errno == ERR_MOD:
373
+ output_message(path, line_nr, "ERR_MOD", "Module docstring missing, or too late")
374
+ if errno == ERR_LIN:
375
+ output_message(path, line_nr, "ERR_LIN", "Line has more than 100 characters")
376
+ if errno == ERR_OPT:
377
+ output_message(path, line_nr, "ERR_OPT", "Forbidden set_option command")
378
+ if errno == ERR_AUT:
379
+ output_message(path, line_nr, "ERR_AUT", "Authors line should look like: 'Authors: Jean Dupont, Иван Иванович Иванов'")
380
+ if errno == ERR_TAC:
381
+ output_message(path, line_nr, "ERR_TAC", "Files in mathlib cannot import the whole tactic folder")
382
+ if errno == ERR_IBY:
383
+ output_message(path, line_nr, "ERR_IBY", "Line is an isolated 'by'")
384
+ if errno == ERR_DOT:
385
+ output_message(path, line_nr, "ERR_DOT", "Line is an isolated focusing dot or uses . instead of ·")
386
+ if errno == ERR_SEM:
387
+ output_message(path, line_nr, "ERR_SEM", "Line contains a space before a semicolon")
388
+ if errno == ERR_WIN:
389
+ output_message(path, line_nr, "ERR_WIN", "Windows line endings (\\r\\n) detected")
390
+ if errno == ERR_TWS:
391
+ output_message(path, line_nr, "ERR_TWS", "Trailing whitespace detected on line")
392
+ if errno == ERR_CLN:
393
+ output_message(path, line_nr, "ERR_CLN", "Put : and := before line breaks, not after")
394
+ if errno == ERR_IND:
395
+ output_message(path, line_nr, "ERR_IND", "If the theorem/def statement requires multiple lines, indent it correctly (4 spaces or 2 for `|`)")
396
+ if errno == ERR_ARR:
397
+ output_message(path, line_nr, "ERR_ARR", "Missing space after '←'.")
398
+ if errno == ERR_NSP:
399
+ output_message(path, line_nr, "ERR_NSP", "Non-terminal simp. Replace with `simp?` and use the suggested output")
400
+
401
+ def lint(path, fix=False):
402
+ global new_exceptions
403
+ with path.open(encoding="utf-8", newline="") as f:
404
+ # We enumerate the lines so that we can report line numbers in the error messages correctly
405
+ # we will modify lines as we go, so we need to keep track of the original line numbers
406
+ lines = f.readlines()
407
+ enum_lines = enumerate(lines, 1)
408
+ newlines = enum_lines
409
+ for error_check in [line_endings_check,
410
+ four_spaces_in_second_line,
411
+ long_lines_check,
412
+ isolated_by_dot_semicolon_check,
413
+ set_option_check,
414
+ left_arrow_check,
415
+ nonterminal_simp_check]:
416
+ errs, newlines = error_check(newlines, path)
417
+ format_errors(errs)
418
+
419
+ if not import_only_check(newlines, path):
420
+ # Check for too long files: either longer than 1500 lines, or not covered by an exception.
421
+ # Each exception contains a "watermark". If the file is longer than that, we also complain.
422
+ if len(lines) > 1500:
423
+ ex = [e for e in exceptions if e[1] == path.resolve()]
424
+ if ex:
425
+ (_ERR_NUM, _path, watermark) = list(ex)[0]
426
+ assert int(watermark) > 500 # protect against parse error
427
+ is_too_long = len(lines) > int(watermark)
428
+ else:
429
+ is_too_long = True
430
+ if is_too_long:
431
+ new_exceptions = True
432
+ # add up to 200 lines of slack, so simple PRs don't trigger this right away
433
+ watermark = len(lines) // 100 * 100 + 200
434
+ output_message(path, 1, "ERR_NUM_LIN", f"{watermark} file contains {len(lines)} lines, try to split it up")
435
+ errs, newlines = regular_check(newlines, path)
436
+ format_errors(errs)
437
+ errs, newlines = banned_import_check(newlines, path)
438
+ format_errors(errs)
439
+ # if we haven't been asked to fix errors, or there are no errors or no fixes, we're done
440
+ if fix and new_exceptions and enum_lines != newlines:
441
+ path.with_name(path.name + '.bak').write_text("".join(l for _,l in newlines), encoding = "utf8")
442
+ shutil.move(path.with_name(path.name + '.bak'), path)
443
+
444
+ fix = "--fix" in sys.argv
445
+ argv = (arg for arg in sys.argv[1:] if arg != "--fix")
446
+
447
+ for filename in argv:
448
+ lint(Path(filename), fix=fix)
449
+
450
+ if new_exceptions:
451
+ exit(1)
repl/.lake/packages/mathlib/scripts/make_port_status.py ADDED
@@ -0,0 +1,218 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import pytz
4
+ import datetime
5
+ import github
6
+ import os
7
+ import re
8
+ import requests
9
+ import subprocess
10
+ import sys
11
+ import yaml
12
+ import networkx as nx
13
+ from collections import defaultdict
14
+ from pathlib import Path
15
+
16
+ # Must run from root of mathlib4 directory.
17
+
18
+ if not os.path.exists('port-repos/mathlib'):
19
+ print("Make sure you are in the root of the mathlib4 directory")
20
+ print("and have checked out mathlib under port-repos/mathlib.")
21
+ sys.exit(1)
22
+
23
+ GITHUB_TOKEN_FILE = 'port-repos/github-token'
24
+ github_token = open(GITHUB_TOKEN_FILE).read().strip()
25
+
26
+ mathlib3_root = 'port-repos/mathlib'
27
+ mathlib4_root = './'
28
+
29
+ source_module_re = re.compile(r"^! .*source module (.*)$")
30
+ commit_re = re.compile(r"^! (leanprover-community/[a-z]*) commit ([0-9a-f]*)")
31
+ import_re = re.compile(r"^import ([^ ]*)")
32
+
33
+ align_import_re = re.compile(
34
+ r'^#align_import ([^ ]*) from "(leanprover-community/[a-z]*)" ?@ ?"([0-9a-f]*)"')
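+ # matches header lines of the form:
+ #   #align_import algebra.group.basic from "leanprover-community/mathlib" @ "<40-char commit sha>"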
35
+
36
+ def mk_label(path: Path) -> str:
37
+ rel = path.relative_to(Path(mathlib3_root))
38
+ rel = Path(*rel.parts[1:])
39
+ return str(rel.with_suffix('')).replace(os.sep, '.')
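+ # e.g. a file port-repos/mathlib/src/algebra/group/basic.lean gets the label "algebra.group.basic"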
40
+
41
+ paths = []
42
+ for path in Path(mathlib3_root).glob('**/*.lean'):
43
+ if path.relative_to(mathlib3_root).parts[0] not in ['src', 'archive', 'counterexamples']:
44
+ continue
45
+ if path.relative_to(mathlib3_root).parts[1] in ['tactic', 'meta']:
46
+ continue
47
+ paths.append(path)
48
+
49
+ graph = nx.DiGraph()
50
+ for path in paths:
51
+ graph.add_node(mk_label(path))
52
+
53
+ for path in paths:
54
+ label = mk_label(path)
55
+ for line in path.read_text().split('\n'):
56
+ m = import_re.match(line)
57
+ if m:
58
+ imported = m.group(1)
59
+ if imported.startswith('tactic.') or imported.startswith('meta.') or imported.startswith('.'):
60
+ continue
61
+ if imported not in graph.nodes:
62
+ if imported + '.default' in graph.nodes:
63
+ imported = imported + '.default'
64
+ else:
65
+ imported = imported
66
+ graph.add_edge(imported, label)
67
+
68
+ def get_mathlib4_module_commit_info(contents):
69
+ module = repo = commit = None
70
+ for line in contents.split('\n'):
71
+ m = align_import_re.match(line)
72
+ if m:
73
+ module = m.group(1)
74
+ repo = m.group(2)
75
+ commit = m.group(3)
76
+ break
77
+ m = source_module_re.match(line)
78
+ if m:
79
+ module = m.group(1)
80
+ m = commit_re.match(line)
81
+ if m:
82
+ repo = m.group(1)
83
+ commit = m.group(2)
84
+ return module, repo, commit
85
+
86
+ # contains ported files
87
+ # lean 3 module name -> { mathlib4_file, mathlib3_hash }
88
+ data = dict()
89
+ for path4 in Path(mathlib4_root).glob('**/*.lean'):
90
+ # we definitely do not want to look in `port-repos` here!
91
+ if path4.relative_to(mathlib4_root).parts[0] not in ('Mathlib', 'Archive', 'Counterexamples'):
92
+ continue
93
+ module, repo, commit = get_mathlib4_module_commit_info(path4.read_text())
94
+ if module is None:
95
+ continue
96
+
97
+ if commit is None:
98
+ print(f"Commit is None for module: {module}")
99
+ continue
100
+
101
+ log = subprocess.run(
102
+ ['git', 'log', '--oneline', str(path4)],
103
+ capture_output=True)
104
+ pr_matches = re.search(r'#([0-9]+)\)$', log.stdout.decode().splitlines()[-1])
105
+ if pr_matches:
106
+ mathlib4_pr = int(pr_matches.groups()[0])
107
+ else:
108
+ mathlib4_pr = None
109
+
110
+ data[module] = {
111
+ 'mathlib4_file': str(path4.relative_to(mathlib4_root)),
112
+ 'mathlib4_pr': mathlib4_pr,
113
+ 'source': dict(repo=repo, commit=commit)
114
+ }
115
+
116
+ graph.add_node(module)
117
+
118
+ prs = {}
119
+ fetch_args = ['git', 'fetch', 'origin']
120
+ nums = []
121
+ sync_prs = defaultdict(set)
122
+ mathlib4repo = github.Github(github_token).get_repo("leanprover-community/mathlib4")
123
+ for pr in mathlib4repo.get_pulls(state='open'):
124
+ if pr.created_at < datetime.datetime(2022, 12, 1, 0, 0, 0, tzinfo=pytz.UTC):
125
+ continue
126
+ if 'no-source-header' in (l.name for l in pr.labels):
127
+ continue
128
+ if 'mathlib3-pair' in (l.name for l in pr.labels):
129
+ for file in (f.filename for f in pr.get_files()):
130
+ sync_prs[file].add(pr.number)
131
+ num = pr.number
132
+ nums.append(num)
133
+ prs[num] = pr
134
+ fetch_args.append(f'pull/{num}/head:port-status-pull/{num}')
135
+
136
+ os.system("git branch -D $(git branch --list 'port-status-pull/*')")
137
+ subprocess.run(fetch_args)
138
+
139
+ prs_of_import = {}
140
+ for num in nums:
141
+ p = subprocess.run(
142
+ ['git', 'diff', '--name-only', '--diff-filter=A',
143
+ f'origin/master...port-status-pull/{num}'],
144
+ capture_output=True)
145
+ for l in p.stdout.decode().splitlines():
146
+ f = subprocess.run(
147
+ ['git', 'cat-file', 'blob', f'port-status-pull/{num}:{l}'],
148
+ capture_output=True)
149
+ import_, repo, commit = get_mathlib4_module_commit_info(f.stdout.decode(encoding='utf8', errors='replace'))
150
+ prs_of_import.setdefault(import_, []).append({'pr': num, 'repo': repo, 'commit': commit, 'fname': l})
151
+
152
+ COMMENTS_URL = "https://raw.githubusercontent.com/wiki/leanprover-community/mathlib4/port-comments.md"
153
+ comments_dict = yaml.safe_load(requests.get(COMMENTS_URL).content.replace(b"```", b""))
154
+
155
+ yaml_dict = {}
156
+ new_yaml_dict = {}
157
+ for node in sorted(graph.nodes):
158
+ if node in data:
159
+ new_status = dict(
160
+ ported=True,
161
+ mathlib4_file=data[node]['mathlib4_file'],
162
+ mathlib4_pr=data[node]['mathlib4_pr'],
163
+ source=data[node]['source']
164
+ )
165
+ _sync_prs = [
166
+ dict(
167
+ num=sync_pr_num,
168
+ labels=[dict(name=l.name, color=l.color) for l in prs[sync_pr_num].labels]
169
+ )
170
+ for sync_pr_num in sync_prs[data[node]['mathlib4_file']]
171
+ ]
172
+ if _sync_prs:
173
+ new_status.update(mathlib4_sync_prs=_sync_prs)
174
+ pr_status = f"mathlib4#{data[node]['mathlib4_pr']}" if data[node]['mathlib4_pr'] is not None else "_"
175
+ sha = data[node]['source']['commit'] if data[node]['source']['repo'] == 'leanprover-community/mathlib' else "_"
176
+
177
+ status = f"Yes {pr_status} {sha}"
178
+ else:
179
+ new_status = dict(ported=False)
180
+ status = f'No'
181
+ if node in prs_of_import:
182
+ pr_info = prs_of_import[node][0]
183
+ if pr_info['commit'] is None:
184
+ print('PR seems to be missing a source header', node, pr_info)
185
+ assert(False)
186
+ new_status.update(
187
+ mathlib4_pr=pr_info['pr'],
188
+ mathlib4_file=pr_info['fname'],
189
+ source=dict(repo=pr_info['repo'], commit=pr_info['commit']))
190
+ labels = [{'name': l.name, 'color': l.color} for l in prs[pr_info['pr']].labels]
191
+ if labels:
192
+ new_status.update(labels=labels)
193
+ sha = pr_info['commit'] if pr_info['repo'] == 'leanprover-community/mathlib' else "_"
194
+ status += f" mathlib4#{pr_info['pr']} {sha}"
195
+ try:
196
+ comment_data = comments_dict[node]
197
+ except KeyError:
198
+ pass
199
+ else:
200
+ if isinstance(comment_data, str):
201
+ # old comment format
202
+ comment_data = dict(message=comment_data)
203
+ # new comment format
204
+ status += ' ' + comment_data['message']
205
+ new_status.update(comment=comment_data)
206
+ yaml_dict[node] = status
207
+ new_yaml_dict[node] = new_status
208
+
209
+ DO_NOT_EDIT_MESSAGE = """
210
+ # Do not edit this file.
211
+ # If you want to add free-form comments about files that don't have PRs yet,
212
+ # edit https://github.com/leanprover-community/mathlib4/wiki/port-comments/_edit instead.
213
+ """ + ("\n" * 37)
214
+
215
+ with open('port_status.yaml', 'w') as f:
216
+ f.write(DO_NOT_EDIT_MESSAGE + "```\n" + yaml.dump(yaml_dict) + "```\n")
217
+ with open('port_status_new.yaml', 'w') as f:
218
+ f.write(DO_NOT_EDIT_MESSAGE + "```\n" + yaml.dump(new_yaml_dict) + "```\n")
repl/.lake/packages/mathlib/scripts/polyrith_sage.py ADDED
@@ -0,0 +1,101 @@
1
+ # This file is part of the `polyrith` tactic in `src/tactic/polyrith.lean`.
2
+ # It interfaces between Lean and the Sage web interface.
3
+
4
+ import requests
5
+ import json
6
+ import sys
7
+ from os.path import join, dirname
8
+
9
+ # These functions are used to format the output of Sage for parsing in Lean.
10
+ # They are stored here as a string since they are passed to Sage via the web API.
11
+ with open(join(dirname(__file__), "polyrith_sage_helper.py"), encoding='utf8') as f:
12
+ polynomial_formatting_functions = f.read()
13
+
14
+ # future extensions may change behavior depending on the base type
15
+ def type_str(type):
16
+ return "QQ"
17
+
18
+ def create_query(type: str, n_vars: int, eq_list, goal_type):
19
+ """ Create a query to invoke Sage's `MPolynomial_libsingular.lift`. See
20
+ https://github.com/sagemath/sage/blob/f8df80820dc7321dc9b18c9644c3b8315999670b/src/sage/rings/polynomial/multi_polynomial_libsingular.pyx#L4472-L4518
21
+ for a description of this method. """
22
+ var_list = [f"var{i}" for i in range(n_vars)] + ['aux']
23
+ query = f'''
24
+ if {n_vars!r} != 0:
25
+ P = PolynomialRing({type_str(type)}, {var_list})
26
+ [{", ".join(var_list)}] = P.gens()
27
+ p = P({goal_type})
28
+ gens = {eq_list} + [1 - p*aux]
29
+ I = P.ideal(gens)
30
+ coeffs = P(1).lift(I)
31
+ power = max(cf.degree(aux) for cf in coeffs)
32
+ coeffs = [P(cf.subs(aux = 1/p)*p^power) for cf in coeffs[:int(-1)]]
33
+ print(str(power)+';'+serialize_polynomials(coeffs))
34
+ else:
35
+ # workaround for a Sage shortcoming with `n_vars = 0`,
36
+ # `TypeError: no conversion of this ring to a Singular ring defined`
37
+ # In this case, there is no need to look for membership in the *radical*;
38
+ # we just check for membership in the ideal, and return exponent 1
39
+ # if coefficients are found.
40
+ P = PolynomialRing({type_str(type)}, 'var', 1)
41
+ p = P({goal_type})
42
+ I = P.ideal({eq_list})
43
+ coeffs = p.lift(I)
44
+ print('1;'+serialize_polynomials(coeffs))
45
+ '''
46
+ return query
47
+
48
+ class EvaluationError(Exception):
49
+ def __init__(self, ename, evalue, message='Error in Sage communication'):
50
+ self.ename = ename
51
+ self.evalue = evalue
52
+ self.message = message
53
+ super().__init__(self.message)
54
+
55
+ def parse_response(resp: str) -> str:
56
+ exp, data = resp.split(';', 1)
57
+ return dict(power=int(exp), coeffs=json.loads(data))
58
+
59
+
60
+ def evaluate_in_sage(query: str) -> str:
61
+ data = {'code': query}
62
+ headers = {'content-type': 'application/x-www-form-urlencoded'}
63
+ response = requests.post('https://sagecell.sagemath.org/service', data, headers=headers).json()
64
+ if response['success']:
65
+ return parse_response(response.get('stdout'))
66
+ elif 'execute_reply' in response and 'ename' in response['execute_reply'] and 'evalue' in response['execute_reply']:
67
+ raise EvaluationError(response['execute_reply']['ename'], response['execute_reply']['evalue'])
68
+ else:
69
+ raise Exception(response)
70
+
71
+ def main():
72
+ '''The system args contain the following:
73
+ 0 - the path to this python file
74
+ 1 - a string containing "true" or "false" depending on whether polyrith was called with trace enabled
75
+ 2 - a string representing the base type of the target
76
+ 3 - the number of variables used
77
+ 4 - a list of the polynomial hypotheses/proof terms in terms of the variables
78
+ 5 - a single polynomial representing the target
79
+
80
+ This returns a json object with format:
81
+ ```
82
+ { success: bool,
83
+ data: Optional[list[str]],
84
+ trace: Optional[str],
85
+ name: Optional[str],
86
+ value: Optional[str] }
87
+ ```
88
+ '''
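+ # A hypothetical manual invocation (the polyrith tactic normally builds this call itself):
+ #   python3 polyrith_sage.py false rat 2 "[var0 - var1, var1 - 1]" "var0 - 1"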
89
+ command = create_query(sys.argv[2], int(sys.argv[3]), sys.argv[4], sys.argv[5])
90
+ final_query = polynomial_formatting_functions + "\n" + command
91
+ if sys.argv[1] == 'true': # trace dry run enabled
92
+ output = dict(success=True, trace=command)
93
+ else:
94
+ try:
95
+ output = dict(success=True, data=evaluate_in_sage(final_query))
96
+ except EvaluationError as e:
97
+ output = dict(success=False, name=e.ename, value=e.evalue)
98
+ print(json.dumps(output))
99
+
100
+ if __name__ == "__main__":
101
+ main()
repl/.lake/packages/mathlib/scripts/polyrith_sage_helper.py ADDED
@@ -0,0 +1,13 @@
1
+ # this file will be run by the remote sage server, so should not import local files.
2
+ from typing import Iterable
3
+
4
+ def q_arr(coeff: QQ) -> str:
5
+ return "[" + str(coeff.numerator()) + "," + str(coeff.denominator()) + "]"
6
+
7
+ def arr(args: Iterable[str]) -> str:
8
+ return "[" + ",".join(args) + "]"
9
+
10
+ def serialize_polynomials(coeffs) -> str:
11
+ return arr(
12
+ arr(arr([arr(arr([str(t[0]), str(t[1])]) for t in etuple.sparse_iter()), q_arr(coeff)])
13
+ for etuple, coeff in c.dict().items()) for c in coeffs)
repl/.lake/packages/mathlib/scripts/yaml_check.py ADDED
@@ -0,0 +1,69 @@
1
+ """
2
+ This file is copied from the mathlib3 file of the same name.
3
+ It reads in the three yaml files, and translates them to simpler json files that are easier to
4
+ process in Lean.
5
+ """
6
+ from typing import Dict, Optional, Union, Tuple, List
7
+ import yaml
8
+ import json
9
+ import sys
10
+
11
+ TieredDict = Dict[str, Union[Optional[str], 'TieredDict']]
12
+
13
+ def tiered_extract(db: TieredDict) -> List[Tuple[List[str], str]]:
14
+ """From a nested dictionary, return a list of (key_path, values)
15
+ of the deepest level."""
16
+ out = []
17
+ for name, entry in db.items():
18
+ if isinstance(entry, dict):
19
+ for subname, value in tiered_extract(entry):
20
+ out.append(([name] + subname, value))
21
+ else:
22
+ if entry and '/' not in entry:
23
+ out.append(([name], entry))
24
+ return out
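+ # e.g. tiered_extract({"Algebra": {"Groups": "group"}}) returns [(["Algebra", "Groups"], "group")]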
25
+
26
+ def flatten_names(data: List[Tuple[List[str], str]]) -> List[Tuple[str, str]]:
27
+ return [(' :: '.join(id), v) for id, v in data]
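+ # e.g. flatten_names([(["Algebra", "Groups"], "group")]) returns [("Algebra :: Groups", "group")]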
28
+
29
+ def print_list(fn: str, pairs: List[Tuple[str, str]]) -> None:
30
+ with open(fn, 'w', encoding='utf8') as out:
31
+ for (id, val) in pairs:
32
+ out.write(f'{id}\n{val.strip()}\n\n')
33
+
34
+ hundred_yaml = sys.argv[1]
35
+ overview_yaml = sys.argv[2]
36
+ undergrad_yaml = sys.argv[3]
37
+
38
+ with open(hundred_yaml, 'r', encoding='utf8') as hy:
39
+ hundred = yaml.safe_load(hy)
40
+ with open(overview_yaml, 'r', encoding='utf8') as hy:
41
+ overview = yaml.safe_load(hy)
42
+ with open(undergrad_yaml, 'r', encoding='utf8') as hy:
43
+ undergrad = yaml.safe_load(hy)
44
+
45
+ hundred_decls:List[Tuple[str, str]] = []
46
+
47
+ for index, entry in hundred.items():
48
+ title = entry['title']
49
+ if 'decl' in entry:
50
+ hundred_decls.append((f'{index} {title}', entry['decl']))
51
+ elif 'decls' in entry:
52
+ if not isinstance(entry['decls'], list):
53
+ raise ValueError(f"For key {index} ({title}): did you mean `decl` instead of `decls`?")
54
+ hundred_decls = hundred_decls + [(f'{index} {title}', d) for d in entry['decls']]
55
+
56
+ overview_decls = tiered_extract(overview)
57
+ assert all(len(n) == 3 for n, _ in overview_decls)
58
+ overview_decls = flatten_names(overview_decls)
59
+
60
+ undergrad_decls = tiered_extract(undergrad)
61
+ assert all(len(n) >= 3 for n, _ in undergrad_decls)
62
+ undergrad_decls = flatten_names(undergrad_decls)
63
+
64
+ with open('100.json', 'w', encoding='utf8') as f:
65
+ json.dump(hundred_decls, f)
66
+ with open('overview.json', 'w', encoding='utf8') as f:
67
+ json.dump(overview_decls, f)
68
+ with open('undergrad.json', 'w', encoding='utf8') as f:
69
+ json.dump(undergrad_decls, f)
repl/pass_rate.py ADDED
@@ -0,0 +1,195 @@
1
+ import os
2
+ import subprocess
3
+ from argparse import ArgumentParser
4
+ import json
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from tqdm import tqdm
7
+ import tempfile
8
+
9
+ def wrapped_function(item):
10
+ results = []
11
+ passed = 0
12
+ total = 0
13
+
14
+ temp_dir = tempfile.gettempdir()
15
+ temp_file = os.path.join(temp_dir, f"test.lean")
16
+
17
+ with open(temp_file, "w") as f:
18
+ f.write(item['cmd'])
19
+
20
+ # Rest of the function code...
21
+ # Process the item using the temporary file
22
+ # ...
23
+
24
+ # Clean up the temporary file
25
+ data = '{"path": "%s", "allTactics": true}' %(temp_file)
26
+ command = 'echo \'%s\' | lake exe repl' % data
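+ # The repl prints a JSON object on stdout; a file that fails to elaborate typically carries a
+ # "messages" array with the errors (illustrative shape only):
+ #   {"messages": [{"severity": "error", "data": "unsolved goals ..."}], "env": 0}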
27
+
28
+ try:
29
+ result = subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
30
+ stdout = result.stdout.decode('utf-8')
31
+ stderr = result.stderr.decode('utf-8')
32
+ # stdout = result.stdout.decode('utf-8')
33
+ json_stdout = json.loads(stdout)
34
+ if "messages" not in json_stdout.keys():
35
+ passed += 1
36
+ # results.append({'item': item['content'], 'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
37
+ results.append({ 'stdout': stdout, 'stderr': stderr, 'status': 'pass'})
38
+ except subprocess.CalledProcessError as e:
39
+ # results.append({'item': item['content'], 'error': str(e), 'status': 'nopass'})
40
+ results.append({ 'error': str(e), 'status': 'nopass'})
41
+ total += 1
42
+
43
+ pass_rate = passed / (passed + total) * 100
44
+
45
+
46
+ return {'results': results, 'pass_rate': pass_rate}
47
+
48
+ # Set the directory where your .lean files are located
49
+
50
+ # Get a list of all .lean files in the directory
51
+ # lean_files = [f for f in os.listdir(directory) if f.endswith(".lean")]
52
+ # lean_files = ["test/file.lean"]
53
+ def single(command_list):
54
+ results = []
55
+ passed = 0
56
+ total = 0
57
+ for item in tqdm(command_list):
58
+ with open("test/test.lean", "w", encoding = 'utf-8') as f:
59
+ f.write(item['cmd'])
60
+ data = '{"path": "test/test.lean", "allTactics": true}'
61
+ # data = '{"cmd": "%s", "allTactics": true}' % item['cmd']
62
+ command = 'echo \'%s\' | lake exe repl' % data
63
+ try:
64
+ # process = subprocess.Popen(['lake', 'exe', 'repl'], stdin=subprocess.PIPE, stdout=subprocess.PIPE,
65
+ # stderr=subprocess.PIPE)
66
+ # stdout, stderr = process.communicate(input=data.encode(encoding='utf-8'))
67
+ # stdout = stdout.decode('utf-8')
68
+ import pdb
69
+ pdb.set_trace()
70
+ result = subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
71
+ stdout = result.stdout.decode('utf-8')
72
+ json_stdout = json.loads(stdout)
73
+ if "messages" not in json_stdout.keys():
74
+ passed += 1
75
+ stderr = result.stderr.decode('utf-8')
76
+ results.append({
77
+ # 'item': item['content'],
78
+ 'stdout': stdout,
79
+ 'stderr': stderr,
80
+ 'status': 'pass'
81
+ })
82
+ except subprocess.CalledProcessError as e:
83
+ results.append({
84
+ # 'item': item['content'],
85
+ 'error': str(e),
86
+ 'status': 'nopass'
87
+ })
88
+ total += 1
89
+
90
+ # Calculate pass rate
91
+ pass_rate = passed / total * 100
92
+ print(pass_rate)
93
+
94
+ # Save results to a JSON file
95
+ with open('results.json', 'w') as f:
96
+ json.dump({'results': results, 'pass_rate': pass_rate}, f, indent=2, ensure_ascii=False)
97
+
98
+
99
+
100
+
101
+
102
+
103
+
104
+
105
+ def multi(command_list):
106
+ results = []
107
+ passed = 0
108
+ total = 0
109
+ def execute_command(item):
110
+ temp_dir = '/hpc2hdd/home/zyang398/lujianqiao/lean4/repl/tmp'
111
+ temp_file = os.path.join(temp_dir, f"test_{item['index']}.lean") # Ensure unique filenames
112
+ with open(temp_file, "w") as f:
113
+ f.write(item['cmd'])
114
+
115
+ data = '{"path": "%s", "allTactics": true}' % temp_file
116
+ command = f'echo \'{data}\' | lake exe repl'
117
+
118
+ try:
119
+ result = subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
120
+ stdout = result.stdout.decode('utf-8')
121
+ stderr = result.stderr.decode('utf-8')
122
+
123
+ if "messages" not in json.loads(stdout):
124
+ return {'stdout': stdout, 'stderr': stderr, 'status': 'pass'}
125
+ else:
126
+ return {'stdout': stdout, 'stderr': stderr, 'status': 'nopass'}
127
+
128
+ except subprocess.CalledProcessError as e:
129
+ return {'error': str(e), 'status': 'nopass'}
130
+
131
+ os.remove(temp_file)
132
+
133
+ total = len(command_list)
134
+
135
+ with ThreadPoolExecutor(max_workers=32) as executor:
136
+ futures = [executor.submit(execute_command, {'index': i, 'cmd': cmd['cmd']}) for i, cmd in enumerate(command_list)]
137
+ for future in tqdm(futures, total=total, desc="Processing Commands"):
138
+ result = future.result()
139
+ results.append(result)
140
+ if result['status'] == 'pass':
141
+ passed += 1
142
+
143
+ pass_rate = (passed / total) * 100
144
+ print(f"Pass rate: {pass_rate}%")
145
+
146
+ with open('results.json', 'w') as f:
147
+ json.dump({'results': results, 'pass_rate': pass_rate}, f, indent=2, ensure_ascii=False)
148
+
149
+ import re
150
+ def remove_simp_pattern_from_end(s):
151
+ pattern = r'@\[simp\s*.*?\]$'
152
+ return re.sub(pattern, '', s)
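+ # e.g. remove_simp_pattern_from_end('lemma foo ... @[simp]') drops the trailing "@[simp]" attribute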
153
+
154
+ def main(args):
155
+ command_list = []
156
+ for i in range(args.cuda_num):
157
+ with open(f"{args.input_path}/{i}.json", 'r', encoding='utf-8') as rf:
158
+ for line in rf.readlines():
159
+ try:
160
+ json_item = json.loads(line)
161
+ # json_item['content']['statement_poof']
162
+ # json_item['cmd'] = '\n'.join([json_item['content']['working_file'] , json_item['total output'][0]])
163
+ working_env = json_item['content']['working_file']
164
+
165
+ # statement = json_item['content']['statement_poof'].split('\n')
166
+ statement = json_item['total output'][0]
167
+
168
+ json_item['cmd'] = '\n'.join([working_env, statement])
169
+ # print(json_item['cmd'])
170
+ assert len(statement) > 0
171
+ # json_item['cmd'] = '\n'.join([working_env, json_item['total output'][0]])
172
+ except Exception as e:
173
+ print(f"Skipping malformed line: {e}")
174
+ continue
175
+ command_list.append(json_item)
176
+ command_list = command_list
177
+ results = []
178
+ passed = 0
179
+ total = 0
180
+ single(command_list)
181
+
182
+ if __name__ == '__main__':
183
+ arg_parser = ArgumentParser()
184
+ arg_parser.add_argument('--data_path', type=str,
185
+ default='data/grade-school-math-master/grade_school_math/data/test.jsonl')
186
+ arg_parser.add_argument('--input_path', type=str, default='')
187
+ arg_parser.add_argument('--cuda_num', type=int, default=4)
188
+ arg_parser.add_argument('--output_path', type=str, default='total.json')
189
+ arg_parser.add_argument('--generate_method', type=str,
190
+ choices=['single', 'sft', 'comp', 'self_consistency', 'single_consistency'])
191
+ arg_parser.add_argument('--method', type=str, choices=['main', 'test', 'get_data'])
192
+ args = arg_parser.parse_args()
193
+ main(args)
194
+
195
+
whole_generation.py ADDED
@@ -0,0 +1,491 @@
1
+ import argparse
2
+ import random
3
+ import glob
4
+
5
+ from tqdm import tqdm
6
+ import re
7
+ import sys
8
+ import os
9
+ import numpy as np
10
+
11
+ PROMPT_DICT = {
12
+ "lean4": (
13
+ "Statement and proof in natural language:\n\n"
14
+ "{statement_text}\n\n"
15
+ "Translate the statement and proof in natural language to lean4:"
16
+ ),
17
+ "plain": (
18
+ "{statement_text}"
19
+ ),
20
+ "statement": (
21
+ "Statement in natural language:\n"
22
+ "{problem}\n"
23
+ "Translate the statement in natural language to Lean4:"
24
+ ),
25
+ "prompt_no_input": (
26
+ "Below is an instruction that describes a task. "
27
+ "Write a response that appropriately completes the request.\n\n"
28
+ "### Instruction:\n{instruction}\n\n### Response:"
29
+ ),
30
+ }
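+ # For example, PROMPT_DICT['statement'].format(problem="Prove that 1 + 1 = 2.") produces:
+ # "Statement in natural language:\nProve that 1 + 1 = 2.\nTranslate the statement in natural language to Lean4:"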
31
+
32
+
33
+ def generate_few_shot(prompt):
34
+ base_gsm8k_list = [
35
+ {
36
+ 'question': "John and his best friend Steve bought 12 cupcakes together. Each cupcake cost $1.50. If they split the costs evenly, how much did each person pay?",
37
+ 'answer': "The total cost of cupcakes was 1.5*12=$<<1.5*12=18>>18\\nSo they each paid 18/2=$<<18/2=9>>9.",
38
+ 'direct_answer': "9"
39
+ },
40
+ {
41
+ 'question': "Lizzy has to ship 540 pounds of fish that are packed into 30-pound crates. If the shipping cost of each crate is $1.5, how much will Lizzy pay for the shipment?",
42
+ 'answer': "There are 540 pounds / 30 pounds/crate = <<540/30=18>>18 crates of fish needed.\\nHence, the total cost for the shipment is $1.5/crate x 18 crates = $<<1.5*18=27>>27.",
43
+ 'direct_answer': "27"
44
+ },
45
+ {
46
+ 'question': "Tom, Tim, and Paul are collecting photos of cars. Paul has 10 photos more than Tim. Tim has one hundred photos less than the total amount of photos which is 152. How many photos does Tom have?",
47
+ 'answer': "Tim has 152 photos - 100 photos = <<152-100=52>>52 photos.\\nWhen Tim has 52 photos, then Paul has 52 + 10 photos = <<52+10=62>>62 photos.\\nTim and Paul have together 52 photos + 62 photos = <<52+62=114>>114 photos.\\nThat leaves Tom with 152 photos - 114 photos = <<152-114=38>>38 photos.",
48
+ 'direct_answer': "38"
49
+ },
50
+
51
+ ]
52
+ index_list = list(range(len(base_gsm8k_list)))
53
+ random.shuffle(index_list)
54
+ few_shot_example = ""
55
+ for i in index_list:
56
+ item = base_gsm8k_list[i]
57
+ few_shot_example += "Q: " + item['question'] + "\n" + "A: "+ item['answer'] + "\nThe answer is " + item['direct_answer'] + "\n"
58
+
59
+ few_shot_example += "Q: " + prompt + "A: "
60
+ return few_shot_example
61
+
62
+
63
+ def generate_prompt_translate(args, question):
64
+ return PROMPT_DICT['statement'].format(problem= question)
65
+
66
+ def generate_prompt_solver(args, question):
67
+ return PROMPT_DICT['plain'].format(statement_text= question)
68
+
69
+
70
+ def generate_prompt_generation(args, question):
71
+ if args.evaluation_mode == 'generation':
72
+ if args.method == 'zero_shot_cot':
73
+ content = question + " Let's think step by step."
74
+ elif args.method == 'zero_shot':
75
+ content = question
76
+ elif args.method == 'few_shot':
77
+ content = generate_few_shot(question)
78
+ else:
79
+ raise ValueError("we do not method for such model type yet")
80
+
81
+ if "generator" not in args.model_type:
82
+ MODEL_DICT = {
83
+ "llama": (
84
+ "[INST] \n{content}\n [/INST]"
85
+ ),
86
+ "mistral": (
87
+ "<s>[INST] {content} [/INST]"
88
+ ),
89
+ "chatglm": (
90
+ "<|user|> \n{content}\n <|assistant|>"
91
+ ),
92
+ "qianwen": (
93
+ "<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
94
+ ),
95
+ "baichuan": (
96
+ "<reserved_106>{content}<reserved_107>"
97
+ )
98
+ }
99
+
100
+ if args.model_type in ["qianwen", "qianwen-13b", "qianwen-70b"]:
101
+ content = MODEL_DICT['qianwen'].format_map(
102
+ {'content': content}
103
+ )
104
+
105
+ elif args.model_type in ["chatglm"]:
106
+ pass
107
+
108
+
109
+ elif args.model_type in ['llama2-7b-chat']:
110
+ content = MODEL_DICT['llama'].format_map(
111
+ {'content': content}
112
+ )
113
+
114
+ elif args.model_type in ["mistral", 'mixtral']:
115
+ content = MODEL_DICT['mistral'].format_map(
116
+ {'content': content}
117
+ )
118
+
119
+
120
+ return content
121
+
122
+
123
+
124
+
125
+ few_shot_list = [
126
+ {
127
+ 'question': "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?",
128
+ 'answer': "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6.",
129
+ 'direct_answer': "6"
130
+ },
131
+ {
132
+ 'question': "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
133
+ 'answer': "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.",
134
+ 'direct_answer': "5",
135
+ },
136
+ {
137
+ 'question': "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?",
138
+ 'answer': "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39.",
139
+ 'direct_answer': "39",
140
+ },
141
+ {
142
+ 'question': "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?",
143
+ 'answer': "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8.",
144
+ 'direct_answer': "8",
145
+ },
146
+ {
147
+ 'question': "Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?",
148
+ 'answer': "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9.",
149
+ 'direct_answer': "9",
150
+ },
151
+ {
152
+ 'question': "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?",
153
+ 'answer': "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29.",
154
+ 'direct_answer': "29",
155
+ },
156
+ {
157
+ 'question': "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?",
158
+ 'answer': "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls.",
159
+ 'direct_answer': "33",
160
+ },
161
+ {
162
+ 'question': "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?",
163
+ 'answer': "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8.",
164
+ 'direct_answer': "8",
165
+ },
166
+ ]
167
+ import json
168
+
169
+ from collections import Counter
170
+
171
+
172
+ def self_consistency(pairs):
173
+ val_counts = Counter(value for key, value in pairs)
174
+ most = val_counts.most_common(1)[0][0]
175
+ for key, value in pairs:
176
+ if value == most:
177
+ return key
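+ # e.g. self_consistency([("sol A", "9"), ("sol B", "7"), ("sol C", "9")]) finds "9" as the
+ # most common answer and returns "sol A", the first solution that produced it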
178
+
179
+
180
+ #
181
+ def find_feedback(content):
182
+ match = re.search(r'Judgement: (.+)', content)
183
+ if match:
184
+ judgement = match.group(1)
185
+ else:
186
+ judgement = "None"
187
+ return judgement
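+ # e.g. find_feedback("...\nJudgement: correct") returns "correct"; it returns the string "None"
+ # when no "Judgement:" line is present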
188
+
189
+
190
+ def str2bool(s):
191
+ s = s.lower()
192
+ if s == 'true':
193
+ return True
194
+ elif s == 'false':
195
+ return False
196
+ else:
197
+ raise ValueError('invalid value: {}, must be true or false'.format(s))
198
+
199
+
200
+ def parse_arguments():
201
+ parser = argparse.ArgumentParser(description="Zero-shot-CoT")
202
+
203
+ # parser.add_argument(
204
+ # "--dataset", type=str, default="plan",
205
+ # choices=["plan", 'tool_use_awareness', 'tool_selection', 'tool_selection_harder', 'tool_creation_awareness',
206
+ # 'tool_creation_awareness_harder', 'tool_creation',
207
+ # 'arguments_filling'], help="dataset used for experiment")
208
+ parser.add_argument(
209
+ "--cot_trigger_no", type=int, default=1,
210
+ help="A trigger sentence that elicits a model to execute chain of thought"
211
+ )
212
+ parser.add_argument("--dataset", type=str, default="")
213
+ parser.add_argument("--data_path", type=str, default="")
214
+ parser.add_argument("--evaluation_mode", type=str, default="")
215
+ parser.add_argument("--batch_size", type=int, default=1)
216
+ parser.add_argument("--eval_method", type=str, default="")
217
+
218
+ parser.add_argument("--model_path", type=str, default="")
219
+
220
+ parser.add_argument("--model_type", type=str, default="chatglm")
221
+
222
+ parser.add_argument("--output_dir", type=str, default="generation_test")
223
+
224
+ parser.add_argument("--lora_path", type=str, default="")
225
+
226
+ parser.add_argument("--iter_num", type=int, default=1)
227
+ parser.add_argument("--method", type=str, default="few_shot_cot")
228
+ parser.add_argument("--data_question_key", type=str, default="question")
229
+ parser.add_argument("--data_answer_key", type=str, default="answer")
230
+
231
+ parser.add_argument("--sample_num", type=int, default=1)
232
+
233
+ parser.add_argument("--cuda_ind", type=int, default=0)
234
+ parser.add_argument("--tensor_parallel", type=int, default=1)
235
+ parser.add_argument("--cuda_start", type=int, default=0)
236
+ parser.add_argument("--cuda_num", type=int, default=8)
237
+
238
+ parser.add_argument("--load_in_8bit", type=str2bool, default=False)
239
+ parser.add_argument("--rewrite", type=str2bool, default=True)
240
+
241
+ parser.add_argument("--use_typewriter", type=int, default=0)
242
+
243
+ parser.add_argument("--temperature", type=float, default=0.0)
244
+ parser.add_argument("--top_p", type=float, default=1)
245
+ parser.add_argument("--iter_max_new_tokens", type=int, default=512)
246
+ parser.add_argument("--init_max_new_tokens", type=int, default=2048)
247
+ parser.add_argument("--min_new_tokens", type=int, default=1)
248
+ parser.add_argument("--correct_response_format", type=str, default="The correct response is:")
249
+
250
+ args = parser.parse_args()
251
+ if args.evaluation_mode == 'generation':
252
+ if "lean" in args.dataset:
253
+ args.data_question_key = 'model_response'
254
+ args.data_answer_key = 'statement_poof'
255
+
256
+ if args.dataset == "lean4_5k_test":
257
+ args.data_path = "data/lean4_gpt_5k/test/data.jsonl"
258
+ elif args.dataset == "lean4_basic_test":
259
+ args.data_path = "data/lean4_basic/1k_test.jsonl"
260
+ elif args.dataset == "lean4_random_test":
261
+ args.data_path = "data/lean4_random/1k_test.json"
262
+ elif args.dataset == "lean4_random_first_train":
263
+ args.data_path = "data/lean4_random/5k_first.json"
264
+ elif args.dataset == "lean4_random_second_train":
265
+ args.data_path = "data/lean4_random/5k_second.json"
266
+ elif args.dataset == "lean4_random_third_train":
267
+ args.data_path = "data/lean4_random/5k_third.json"
268
+
269
+ if args.model_type == 'mistral_generator':
270
+ args.model_path = 'models/gsm8k/generators/mistral-ep2/'
271
+ elif args.model_type == 'mistral_generator_original':
272
+ args.model_path = '/data/OVM-Mistral-7b/mistral7b-ep2/'
273
+ elif args.model_type == 'gemma_generator':
274
+ args.model_path = 'models/gsm8k/generators/gemma2b2-ep2/'
275
+ elif args.model_type == 'phi2_generator':
276
+ args.model_path = 'models/gsm8k/generators/phi2b-ep2/'
277
+
278
+ elif args.model_type == 'mixtral':
279
+ args.model_path = '/data/Mixtral-8x7B-Instruct-v0.1'
280
+
281
+ elif args.model_type == 'mistral':
282
+ args.model_path = '/data/mistral-instruct'
283
+
284
+ elif args.model_type == 'qianwen-70b':
285
+ args.model_path = '/data/Qwen-72B-Chat'
286
+
287
+
288
+ elif args.model_type == 'llama2-7b-chat':
289
+ args.model_path = '/data/Llama-2-7b-chat/'
290
+
291
+ if args.cot_trigger_no == 1:
292
+ args.cot_trigger = "Let's think step by step."
293
+
294
+ return args
295
+
296
+
297
+ def create_demo_text(args, cot_flag, index_list):
298
+ # Concatenate demonstration examples ...
299
+ demo_text = ""
300
+ for i in index_list:
301
+ item = few_shot_list[i]
302
+ if cot_flag:
303
+ demo_text += "Q: " + item['question'] + "\nA: " + item['answer'] + " " + \
304
+ args.direct_answer_trigger_for_fewshot + " " + item['direct_answer'] + ".\n\n"
305
+ else:
306
+ demo_text += "Q: " + item['question'] + "\nA: " + \
307
+ args.direct_answer_trigger_for_fewshot + " " + item['direct_answer'] + ".\n\n"
308
+
309
+ return demo_text
310
+
311
+
312
+ def str2bool(s):
313
+ s = s.lower()
314
+ if s == 'true':
315
+ return True
316
+ elif s == 'false':
317
+ return False
318
+ else:
319
+ raise ValueError('invalid value: {}, must be true or false'.format(s))
320
+
321
+
322
+ def batchify(pairs, batch_size):
323
+
324
+ """将列表分成指定大小的批次"""
325
+ for i in range(0, len(pairs), batch_size):
326
+ yield pairs[i:i + batch_size]
327
+
328
+
329
+ def generate_prompts(questions, func, args):
330
+ """为每个问题生成提示"""
331
+ prompts = [func(args, question) for question in questions]
332
+ return prompts
333
+
334
+
335
+ def get_question_answer(args):
336
+ allfilepath = args.data_path
337
+ questions = []
338
+ answers = []
339
+
340
+ # Attempt to read the file as a regular JSON file
341
+ for filepath in allfilepath.split(','):
342
+ try:
343
+ with open(filepath, 'r') as file:
344
+ data = json.load(file)
345
+ # If the data is a list, assume it's an array of objects
346
+ if isinstance(data, list):
347
+ for json_item in data:
348
+ questions.append(json_item[args.data_question_key])
349
+ answers.append(json_item)
350
+ # If the data is a dict, assume it's a single object (or adjust logic as needed)
351
+ elif isinstance(data, dict):
352
+ questions.append(data[args.data_question_key])
353
+ answers.append(json_item)
354
+
355
+ except ValueError:
356
+ # If it fails, assume the file is in JSON Lines format
357
+ with open(filepath, 'r') as file:
358
+ for line in file:
359
+ json_item = json.loads(line)
360
+ questions.append(json_item[args.data_question_key])
361
+ answers.append(json_item)
362
+
363
+ questions = [ PROMPT_DICT['lean4'].format(statement_text = item) for item in questions]
364
+
365
+ return questions, answers
366
+
367
+
368
+ def main3(args):
369
+ from vllm import LLM, SamplingParams
370
+ import torch
371
+
372
+
373
+
374
+ print("load data")
375
+
376
+
377
+
378
+ questions, answers = get_question_answer(args)
379
+
380
+
381
+
382
+ question_exist_list = []
383
+ write_pattern = 'w' if args.rewrite else "a+"
384
+ if os.path.exists(args.output_dir) and not args.rewrite :
385
+ # If the output files already exist, read them and collect the questions that were already processed
386
+ # Loop through each file that matches the pattern
387
+ file_pattern = os.path.join(args.output_dir, '[0-9]*.json')
388
+ for file_path in glob.glob(file_pattern):
389
+ # Open and read the JSON file
390
+ with open(file_path, 'r') as fp:
391
+ # Extract the 'question' field from each line and add it to the list
392
+ for line in fp.readlines():
393
+ question_exist_list.append(json.loads(line)['question'])
394
+ else:
395
+ try:
396
+ os.mkdir(args.output_dir)
397
+ except:
398
+ pass
399
+ qa_pairs = [(questions[idx], answers[idx]) for idx in range(len(questions)) if questions[idx] not in question_exist_list ]
400
+ cuda_pieces = np.array_split(range(len(qa_pairs)), args.cuda_num // args.tensor_parallel)
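+ # e.g. np.array_split(range(10), 4) gives chunks of sizes [3, 3, 2, 2], one slice per worker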
401
+ print(f"fitered {len(questions) - len(qa_pairs)} already")
402
+
403
+ with open(f"{args.output_dir}/{args.cuda_ind // args.tensor_parallel + args.cuda_start}.json", write_pattern,
404
+ encoding='utf-8') as wf:
405
+ start = cuda_pieces[args.cuda_start + args.cuda_ind // args.tensor_parallel][0]
406
+ end = cuda_pieces[args.cuda_start + args.cuda_ind // args.tensor_parallel][-1] + 1
407
+ subset_length = end - start
408
+ total_batches = (subset_length + args.batch_size - 1) // args.batch_size # Calculate the total number of batches
409
+ for batch in tqdm(batchify(qa_pairs[start:end], args.batch_size), total=total_batches):
410
+
411
+
412
+ questions, answers = zip(*batch)  # unpack the questions and answers
413
+ with torch.no_grad():
414
+ model = LLM(model=args.translate_model_path, dtype="bfloat16", trust_remote_code=True,
415
+ tensor_parallel_size=args.tensor_parallel, gpu_memory_utilization=0.95)
416
+
417
+ translate_prompts = generate_prompts(questions, generate_prompt_translate, args)
418
+ translate_output_all = []
419
+ try:
420
+ for i in range(args.sample_num):
421
+ sample_list = []
422
+ sampling_params = SamplingParams(temperature=args.temperature, top_p=args.top_p,
423
+ max_tokens=args.init_max_new_tokens)
424
+ generations = model.generate(translate_prompts, sampling_params, use_tqdm=False)
425
+ for generation_output in generations:
426
+ output = generation_output.outputs[0].text
427
+ sample_list.append(output)
428
+ translate_output_all.append(sample_list)
429
+
430
+ translate_output_all = list(map(list, zip(*translate_output_all)))
431
+ except Exception as e:
432
+ print(str(e))
433
+ exit
434
+
435
+ for batch in tqdm(batchify(qa_pairs[start:end], args.batch_size), total=total_batches):
436
+ questions, answers = zip(*batch)  # unpack the questions and answers
437
+ with torch.no_grad():
438
+
439
+ model = LLM(model=args.solver_model_path, dtype="bfloat16", trust_remote_code=True,
440
+ tensor_parallel_size=args.tensor_parallel, gpu_memory_utilization=0.95)
441
+ solver_prompts = generate_prompts(translate_output_all, generate_prompt_solver ,args)
442
+ solver_output_all = []
443
+
444
+ try:
445
+ for i in range(args.sample_num):
446
+ solver_sample_list = []
447
+ sampling_params = SamplingParams(temperature=args.temperature, top_p=args.top_p,
448
+ max_tokens=args.init_max_new_tokens)
449
+ generations = model.generate(solver_prompts, sampling_params, use_tqdm=False)
450
+ for generation_output in generations:
451
+ output = generation_output.outputs[0].text
452
+ solver_sample_list.append(output)
453
+ translate_output_all.append(solver_sample_list)
454
+
455
+ translate_output_all= list(map(list, zip(*translate_output_all)))
456
+
457
+ except Exception as e:
458
+ print(str(e))
459
+ exit
460
+ dicts = []
461
+
462
+ for question, answer,translate_output, translate_prompt, solver_output, solver_prompt in zip(questions, answers, translate_output_all, translate_prompts, solver_output_all, solver_prompts):
463
+ dicts.append({
464
+ "question": question,
465
+ "translate output": translate_output,
466
+ "translate prompt": translate_prompt,
467
+ "soler output": solver_output,
468
+ "soler prompt": solver_prompt,
469
+ "answer": answer,
470
+ })
471
+
472
+ for item in dicts:
473
+ wf.write(json.dumps(item, ensure_ascii=False) + '\n')
474
+
475
+ wf.flush()
476
+
477
+
478
+ def main(argv=None):
479
+ args = parse_arguments()
480
+ print('*****************************')
481
+ print(args)
482
+ print('*****************************')
483
+ if args.evaluation_mode == 'generation':
484
+ main3(args)
485
+ else:
486
+ raise ValueError("we do not yet inplement")
487
+
488
+
489
+ if __name__ == "__main__":
490
+ main()
491
+