DeepLearning101 committed
Commit 2fcdf98 · 1 Parent(s): 6e66a6e

Delete models/tools

models/tools/__init__.py DELETED
@@ -1,4 +0,0 @@
- # -*- coding: utf-8 -*-
- # @Time : 2021/12/2 5:41 p.m.
- # @Author : JianingWang
- # @File : __init__.py
 
models/tools/analysis_toolkits/__init__.py DELETED
File without changes
models/tools/computations/softmax.py DELETED
@@ -1,8 +0,0 @@
- import torch
-
- """
- Transform numpy logits into probabilities via torch.softmax.
- """
- def softmax(logits):
-     probs = torch.softmax(torch.from_numpy(logits).float(), -1).numpy()
-     return probs
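For reference, a minimal usage sketch of the deleted helper (the input is a numpy array, as the `torch.from_numpy` call implies):

```python
import numpy as np

logits = np.array([[2.0, 1.0, 0.1], [0.5, 0.5, 0.5]])  # hypothetical batch of 2 x 3 logits
probs = softmax(logits)
print(probs.sum(axis=-1))  # each row sums to 1.0
```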
 
models/tools/data_structures/__init__.py DELETED
File without changes
models/tools/data_structures/trie.py DELETED
@@ -1,152 +0,0 @@
- # -*- coding: utf-8 -*-
- # @Time : 2022/2/15 7:57 p.m.
- # @Author : JianingWang
- # @File : trie
- import logging
- from typing import List
- from collections import OrderedDict
-
- logger = logging.getLogger(__name__)
-
-
- class Trie:
-     def __init__(self):
-         self.data = {}
-
-     def add(self, word: str):
-         """
-         Passes over every char (utf-8 char) of word and recursively adds it to the internal `data` trie representation.
-         The special key `""` is used to represent termination.
-
-         This function is idempotent: adding the same word twice leaves the trie unchanged.
-
-         Example:
-
-         ```python
-         >>> trie = Trie()
-         >>> trie.add("Hello 友達")
-         >>> trie.data
-         {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}
-
-         >>> trie.add("Hello")
-         >>> trie.data
-         {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
-         ```
-         """
-         if not word:
-             # Prevent empty string
-             return
-         ref = self.data
-         for char in word:
-             ref[char] = char in ref and ref[char] or {}
-             ref = ref[char]
-         ref[""] = 1
-
-     def find(self, text: str):
-         states = OrderedDict()
-         offsets = []
-         skip = 0
-         for current, current_char in enumerate(text):
-             if skip and current < skip:
-                 continue
-             to_remove = set()
-             reset = False
-             for start, trie_pointer in states.items():
-                 if "" in trie_pointer:
-                     for lookstart, looktrie_pointer in states.items():
-                         if lookstart > start:
-                             break
-                         elif lookstart < start:
-                             lookahead_index = current + 1
-                             end = current + 1
-                         else:
-                             lookahead_index = current
-                             end = current
-                         next_char = text[lookahead_index] if lookahead_index < len(text) else None
-                         if "" in looktrie_pointer:
-                             start = lookstart
-                             end = lookahead_index
-                             skip = lookahead_index
-
-                         while next_char in looktrie_pointer:
-                             looktrie_pointer = looktrie_pointer[next_char]
-                             lookahead_index += 1
-                             if "" in looktrie_pointer:
-                                 start = lookstart
-                                 end = lookahead_index
-                                 skip = lookahead_index
-
-                             if lookahead_index == len(text):
-                                 break
-                             next_char = text[lookahead_index]
-                     offsets.append([start, end])
-                     reset = True
-                     break
-                 elif current_char in trie_pointer:
-                     trie_pointer = trie_pointer[current_char]
-                     states[start] = trie_pointer
-                 else:
-                     to_remove.add(start)
-             if reset:
-                 states = {}
-             else:
-                 for start in to_remove:
-                     del states[start]
-             if current >= skip and current_char in self.data:
-                 states[current] = self.data[current_char]
-         for start, trie_pointer in states.items():
-             if "" in trie_pointer:
-                 end = len(text)
-                 offsets.append([start, end])
-                 break
-
-         return offsets
-
-     def split(self, text: str) -> List[str]:
-         """
-         Example:
-
-         ```python
-         >>> trie = Trie()
-         >>> trie.split("[CLS] This is a extra_id_100")
-         ["[CLS] This is a extra_id_100"]
-
-         >>> trie.add("[CLS]")
-         >>> trie.add("extra_id_1")
-         >>> trie.add("extra_id_100")
-         >>> trie.split("[CLS] This is a extra_id_100")
-         ["[CLS]", " This is a ", "extra_id_100"]
-         ```
-         """
-         word_sets = self.find(text)
-         offsets = [0]
-         for w in word_sets:
-             offsets.extend(w)
-         return self.cut_text(text, offsets)
-
-     def cut_text(self, text, offsets):
-         offsets.append(len(text))
-         tokens = []
-         start = 0
-         for end in offsets:
-             if start > end:
-                 logger.error(
-                     "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway."
-                 )
-                 continue
-             elif start == end:
-                 continue
-             tokens.append(text[start:end])
-             start = end
-
-         return tokens
-
-     def __reduce__(self):
-         return None
-
-
- if __name__ == "__main__":
-     trie = Trie()
-     for word in ["A", "AB", "BD", "BWA"]:
-         trie.add(word)
-     print(trie.__reduce__())
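For reference, a small usage sketch of the deleted `Trie` (offsets are `[start, end)` character positions, mirroring the docstring example):

```python
trie = Trie()
trie.add("[CLS]")
trie.add("extra_id_100")
text = "[CLS] This is a extra_id_100"
print(trie.find(text))   # [[0, 5], [16, 28]] -- character offsets of the matched words
print(trie.split(text))  # ["[CLS]", " This is a ", "extra_id_100"]
```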
 
models/tools/model_utils/__init__.py DELETED
File without changes
models/tools/model_utils/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (139 Bytes)
 
models/tools/model_utils/__pycache__/parameter_freeze.cpython-38.pyc DELETED
Binary file (2.8 kB)
 
models/tools/model_utils/calibrate.py DELETED
@@ -1,202 +0,0 @@
- # -*- coding: utf-8 -*-
- # @Time    : 2023/3/20 8:02 p.m.
- # @Author  : Jianing Wang
- # @File    : calibrate.py
-
- import os
- import numpy as np
- import torch
-
- """
- Use an LM to classify label words for calibrating CLS.
- """
- class CLSCalibrator:
-     pass
-
- """
- Use a causal LM to generate label words for calibrating CLS,
- e.g., use GPT-2 to generate a label word from in-context prompts and calibrate the resulting prediction.
- Paper: http://proceedings.mlr.press/v139/zhao21c.html
- """
- class CausalCLSCalibrator:
-
-     def __init__(self, model, tokenizer) -> None:
-         self.model = model
-         self.tokenizer = tokenizer
-
-     def calibrate(self, all_label_probs, content_free_examples, label2id, mode="diagonal_W"):
-         """Perform calibration for de-biasing and obtain calibrated probabilities."""
-         p_cf = self.get_content_free_prediction(content_free_examples, label2id)
-
-         num_classes = all_label_probs.shape[1]
-         if p_cf is None:
-             # do not calibrate
-             W = np.identity(num_classes)
-             b = np.zeros([num_classes, 1])
-         else:
-             # calibrate
-             if mode == "diagonal_W":
-                 W = np.linalg.inv(np.identity(num_classes) * p_cf)
-                 b = np.zeros([num_classes, 1])
-             elif mode == "identity_W":
-                 W = np.identity(num_classes)
-                 b = -1 * np.expand_dims(p_cf, axis=-1)
-             else:
-                 assert False
-
-         all_calibrate_label_probs = list()
-         for label_probs in all_label_probs:
-             label_probs = label_probs / np.sum(label_probs)  # normalize to 1
-             calibrate_label_probs = np.matmul(W, np.expand_dims(label_probs, axis=-1)) + b
-             all_calibrate_label_probs.append(calibrate_label_probs.squeeze().tolist())
-         return np.array(all_calibrate_label_probs)
-
-     def get_content_free_prediction(self, content_free_examples, label2id: dict):
-         """Query the model with content-free input and return its prediction probability for each label."""
-
-         all_p_y = []
-         for content_free_example in content_free_examples:
-
-             content_free_prompt = content_free_example["content_free_prompt"]
-             p_y = [0] * len(label2id)
-             for answers, i in label2id.items():
-                 prob = 0
-                 for a in answers:
-                     prob += np.exp(self.get_causal_cls_prediction(content_free_prompt + " " + a, 0, echo=True, num_log_probs=1)['choices'][0]['logprobs']['token_logprobs'][-1])
-                 p_y[i] = prob
-             all_p_y.append(p_y)
-
-         p_y = np.mean(np.array(all_p_y), axis=0)
-         p_y = p_y / np.sum(p_y)  # normalize
-         return p_y
-
-     def get_causal_cls_prediction(self, prompt, l=10, num_log_probs=None, echo=False):
-         '''This function runs GPT-2 locally but places the outputs into a JSON object that looks just like the one
-         provided by the OpenAI API.'''
-         if isinstance(prompt, str):
-             prompt = [prompt]  # the code below assumes a list
-         input_ids = self.tokenizer.batch_encode_plus(prompt, return_tensors="pt", padding=True)
-
-         if l + len(input_ids['input_ids'][0]) > 1020:
-             m = l + len(input_ids['input_ids'][0]) - 1024
-             input_ids['input_ids'] = torch.Tensor([input_ids['input_ids'][0][m:].numpy()]).long()
-             input_ids['attention_mask'] = torch.Tensor([input_ids['attention_mask'][0][m:].numpy()]).long()
-
-         # greedily generate l tokens
-         # print("l=", l)
-         if l > 0:
-             # the generate function can handle left padded inputs automatically in HF
-             # total_sequences is now the input + possible generated output
-             # print("l + len(input_ids[input_ids][0]=", l + len(input_ids['input_ids'][0]))
-             total_sequences = self.model.generate(
-                 input_ids=input_ids['input_ids'].to(self.model.device),
-                 attention_mask=input_ids['attention_mask'].to(self.model.device),
-                 max_length=l + len(input_ids['input_ids'][0]),
-                 do_sample=False
-             )
-         else:
-             assert echo == True and l == 0
-             total_sequences = input_ids['input_ids'].to(self.model.device)
-         # print("="*50)
-         # print("total_sequences=", total_sequences)  [batch, len+l]
-         # print("total_sequences.shape=", total_sequences.shape)
-
-         # they want the probs of the top tokens
-         if num_log_probs is not None:
-             # we are left padding, so we need to adjust the position IDs
-             attention_mask = (total_sequences != 50256).float()
-             position_ids = attention_mask.long().cumsum(-1) - 1
-             position_ids.masked_fill_(attention_mask == 0, 1)
-             # get the logits for the context and the next l tokens
-             logits = self.model.forward(input_ids=total_sequences, attention_mask=attention_mask, position_ids=position_ids, return_dict=True).logits.detach().cpu()
-             if not echo:
-                 # get the top tokens and probs for the generated l tokens
-                 probs = torch.softmax(logits[:, -l-1:], dim=2).cpu()
-             else:
-                 # get the top tokens and probs for the context and the generated l tokens
-                 probs = torch.softmax(logits, dim=2).cpu()
-             top_probs, top_tokens = torch.topk(probs, k=num_log_probs)
-             logprobs = torch.log(probs)
-             top_log_probs = torch.log(top_probs)
-             # print("top_log_probs=", top_log_probs)
-             # print("top_log_probs.shape=", top_log_probs.shape)  # [1, 2, 100] [batch, 2, api_num_log_prob]
-
-         # create the return value to resemble OpenAI
-         return_json = {}
-         choices = []
-         # print("="*50)
-         for batch_id in range(len(prompt)):
-             curr_json = {}
-             # text is just the optional context and next l tokens
-             if not echo:
-                 curr_json['text'] = self.tokenizer.decode(total_sequences[batch_id][-l:], skip_special_tokens=True)
-             else:
-                 curr_json['text'] = self.tokenizer.decode(total_sequences[batch_id], skip_special_tokens=True)
-
-             # fill the return json with the top tokens and probs to match the OpenAI return value.
-             if num_log_probs is not None:
-                 curr_json['logprobs'] = {}
-                 curr_json['logprobs']['top_logprobs'] = []
-                 curr_json['logprobs']['token_logprobs'] = []
-                 curr_json['logprobs']['tokens'] = []
-                 if not echo:
-                     # cutoff the -1 here because the probs are shifted one over for LMs
-                     for current_element_top_log_probs, current_element_top_tokens in zip(top_log_probs[batch_id][:-1], top_tokens[batch_id][:-1]):
-                         # tokens is a list of the top token at each position
-                         curr_json['logprobs']['tokens'].append(self.tokenizer.decode([current_element_top_tokens[0]]))
-                         # token_logprobs is a list of the logprob of the top token at each position
-                         curr_json['logprobs']['token_logprobs'].append(current_element_top_log_probs[0].item())
-                         # top_logprobs is a list of dicts for the top K tokens, with each entry being {'token_name': log_prob}
-                         temp = {}
-                         for log_prob, token in zip(current_element_top_log_probs, current_element_top_tokens):
-                             temp[self.tokenizer.decode(token.item())] = log_prob.item()
-                         curr_json['logprobs']['top_logprobs'].append(temp)
-                 else:
-                     # same as above but with small tweaks
-                     # we add null to the front because the GPT models have null probability for the first token
-                     # (for some reason they don't have a beginning-of-sentence token)
-                     curr_json['logprobs']['top_logprobs'].append('null')
-                     # cutoff the -1 here because the probs are shifted one over for LMs
-                     for index, (current_element_top_log_probs, current_element_top_tokens) in enumerate(zip(top_log_probs[batch_id][:-1], top_tokens[batch_id][:-1])):
-                         # skip padding tokens
-                         if total_sequences[batch_id][index].item() == 50256:
-                             continue
-                         temp = {}
-                         for log_prob, token in zip(current_element_top_log_probs, current_element_top_tokens):
-                             temp[self.tokenizer.decode(token.item())] = log_prob.item()
-                         curr_json['logprobs']['top_logprobs'].append(temp)
-                     for index in range(len(probs[batch_id])):
-                         curr_json['logprobs']['tokens'].append(self.tokenizer.decode([total_sequences[batch_id][index]]))
-                     curr_json['logprobs']['token_logprobs'].append('null')
-                     for index, log_probs_token_position_j in enumerate(logprobs[batch_id][:-1]):
-                         # probs are left shifted for LMs
-                         curr_json['logprobs']['token_logprobs'].append(log_probs_token_position_j[total_sequences[batch_id][index+1]])
-
-             choices.append(curr_json)
-             # print("curr_json=", curr_json)
-             '''
-             e.g.,
-             num_tokens_to_predict=1
-             curr_json= {
-                 'text': ' I',  # the generated top token
-                 'logprobs': {'top_logprobs': [{' I': -3.4267239570617676, '\n': -3.5073862075805664, ...],  # top-100 tokens and their scores
-                              'token_logprobs': [-3.4267239570617676],  # score of the current top token
-                              'tokens': [' I']}
-             }
-             num_tokens_to_predict=2
-             curr_json= {
-                 'text': '\nThe',  # two tokens when two tokens are requested
-                 'logprobs': {'top_logprobs': [  # predicted scores at the two positions
-                                  {'\n': -3.186706304550171, '\xa0': -3.222092390060425, ' We': -6.781067848205566, ...},
-                                  {'The': -2.5251243114471436, '"': -2.857935667037964, ...],
-                              'token_logprobs': [-3.186706304550171, -2.5251243114471436],  # scores of the generated tokens
-                              'tokens': ['\n', 'The']}
-             }
-             '''
-         return_json['choices'] = choices
-         # print("="*50)
-         # print("return_json=", return_json)
-         return return_json
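For context, the `diagonal_W` branch implements contextual calibration from the linked paper: the content-free prior `p_cf` is inverted and applied to each normalized prediction. A minimal numpy sketch, with an assumed three-label prior:

```python
import numpy as np

p_cf = np.array([0.7, 0.2, 0.1])          # assumed content-free prior over 3 labels
W = np.linalg.inv(np.identity(3) * p_cf)  # diagonal_W: W = diag(p_cf)^-1
b = np.zeros([3, 1])

label_probs = np.array([0.5, 0.3, 0.2])   # raw, normalized label probabilities
calibrated = np.matmul(W, label_probs[:, None]) + b
print(calibrated.squeeze())               # [0.714.. 1.5 2.0] -> argmax flips from label 0 to label 2
```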
 
models/tools/model_utils/gpt_response.py DELETED
@@ -1,138 +0,0 @@
- # -*- coding: utf-8 -*-
- # @Time    : 2023/3/23 1:02 p.m.
- # @Author  : Jianing Wang
- # @File    : gpt_response.py
-
- import os
- import sys
- import torch
- import openai
- import time
-
- """
- Call a GPT-style LLM.
- The output format is the same as OpenAI's (e.g., GPT-3.5 text-davinci-003).
- """
- class GPTResponse:
-
-     def __init__(self, model_type: str, data_path: str) -> None:
-         assert model_type in ["gpt2", "gpt3"]
-         self.model_type = model_type
-         if self.model_type == "gpt3":
-
-             with open(os.path.join(data_path, 'openai_key.txt'), 'r') as f:
-                 key = f.readline().strip()
-                 openai.api_key = key
-
-     def call_for_gpt3_response(self, prompt, l, model_name, temp=0, num_log_probs=None, echo=False, n=None):
-         """
-         Call the GPT-3 API until a result is provided, then return it.
-         """
-         response = None
-         received = False
-         while not received:
-             try:
-                 response = openai.Completion.create(engine=model_name, prompt=prompt, max_tokens=l, temperature=temp,
-                                                     logprobs=num_log_probs, echo=echo, stop='\n', n=n)
-                 received = True
-             except:
-                 error = sys.exc_info()[0]
-                 if error == openai.error.InvalidRequestError:  # something is wrong: e.g. prompt too long
-                     print(f"InvalidRequestError\nPrompt passed in:\n\n{prompt}\n\n")
-                     assert False
-
-                 print("API error:", error)
-                 time.sleep(1)
-         return response
-
-     def call_for_gpt2_response(self, gpt2_tokenizer, logits, total_sequences, l=10, num_log_probs=None, echo=False, n=None):
-         """
-         Obtain the prediction logits from a local GPT-2 and convert them to match the response format of the OpenAI API.
-         """
-         if not echo:
-             # get the top tokens and probs for the generated l tokens
-             probs = torch.softmax(logits[:, -l-1:], dim=2).cpu()
-         else:
-             # get the top tokens and probs for the context and the generated l tokens
-             probs = torch.softmax(logits, dim=2).cpu()
-         # print("probs=", probs)
-         top_probs, top_tokens = torch.topk(probs, k=num_log_probs)
-         logprobs = torch.log(probs)
-         top_log_probs = torch.log(top_probs)
-
-         # create the return value to resemble OpenAI
-         return_json = {}
-         choices = []
-         # print("="*50)
-         for batch_id in range(len(logits)):
-             curr_json = {}
-             # text is just the optional context and next l tokens
-             if not echo:
-                 curr_json['text'] = gpt2_tokenizer.decode(total_sequences[batch_id][-l:], skip_special_tokens=True)
-             else:
-                 curr_json['text'] = gpt2_tokenizer.decode(total_sequences[batch_id], skip_special_tokens=True)
-
-             # fill the return json with the top tokens and probs to match the OpenAI return value.
-             if num_log_probs is not None:
-                 curr_json['logprobs'] = {}
-                 curr_json['logprobs']['top_logprobs'] = []
-                 curr_json['logprobs']['token_logprobs'] = []
-                 curr_json['logprobs']['tokens'] = []
-                 if not echo:
-                     # cutoff the -1 here because the probs are shifted one over for LMs
-                     for current_element_top_log_probs, current_element_top_tokens in zip(top_log_probs[batch_id][:-1], top_tokens[batch_id][:-1]):
-                         # tokens is a list of the top token at each position
-                         curr_json['logprobs']['tokens'].append(gpt2_tokenizer.decode([current_element_top_tokens[0]]))
-                         # token_logprobs is a list of the logprob of the top token at each position
-                         curr_json['logprobs']['token_logprobs'].append(current_element_top_log_probs[0].item())
-                         # top_logprobs is a list of dicts for the top K tokens, with each entry being {'token_name': log_prob}
-                         temp = {}
-                         for log_prob, token in zip(current_element_top_log_probs, current_element_top_tokens):
-                             temp[gpt2_tokenizer.decode(token.item())] = log_prob.item()
-                         curr_json['logprobs']['top_logprobs'].append(temp)
-                 else:
-                     # same as above but with small tweaks
-                     # we add null to the front because the GPT models have null probability for the first token
-                     # (for some reason they don't have a beginning-of-sentence token)
-                     curr_json['logprobs']['top_logprobs'].append('null')
-                     # cutoff the -1 here because the probs are shifted one over for LMs
-                     for index, (current_element_top_log_probs, current_element_top_tokens) in enumerate(zip(top_log_probs[batch_id][:-1], top_tokens[batch_id][:-1])):
-                         # skip padding tokens
-                         if total_sequences[batch_id][index].item() == 50256:
-                             continue
-                         temp = {}
-                         for log_prob, token in zip(current_element_top_log_probs, current_element_top_tokens):
-                             temp[gpt2_tokenizer.decode(token.item())] = log_prob.item()
-                         curr_json['logprobs']['top_logprobs'].append(temp)
-                     for index in range(len(probs[batch_id])):
-                         curr_json['logprobs']['tokens'].append(gpt2_tokenizer.decode([total_sequences[batch_id][index]]))
-                     curr_json['logprobs']['token_logprobs'].append('null')
-                     for index, log_probs_token_position_j in enumerate(logprobs[batch_id][:-1]):
-                         # probs are left shifted for LMs
-                         curr_json['logprobs']['token_logprobs'].append(log_probs_token_position_j[total_sequences[batch_id][index+1]])
-
-             choices.append(curr_json)
-             # print("curr_json=", curr_json)
-             '''
-             e.g.,
-             num_tokens_to_predict=1
-             curr_json= {
-                 'text': ' I',  # the generated top token
-                 'logprobs': {'top_logprobs': [{' I': -3.4267239570617676, '\n': -3.5073862075805664, ...],  # top-100 tokens and their scores
-                              'token_logprobs': [-3.4267239570617676],  # score of the current top token
-                              'tokens': [' I']}
-             }
-             num_tokens_to_predict=2
-             curr_json= {
-                 'text': '\nThe',  # two tokens when two tokens are requested
-                 'logprobs': {'top_logprobs': [  # predicted scores at the two positions
-                                  {'\n': -3.186706304550171, '\xa0': -3.222092390060425, ' We': -6.781067848205566, ...},
-                                  {'The': -2.5251243114471436, '"': -2.857935667037964, ...],
-                              'token_logprobs': [-3.186706304550171, -2.5251243114471436],  # scores of the generated tokens
-                              'tokens': ['\n', 'The']}
-             }
-             '''
-         return_json['choices'] = choices
-         # print("="*50)
-         # print("return_json=", return_json)
-         return return_json
 
models/tools/model_utils/parameter_freeze.py DELETED
@@ -1,126 +0,0 @@
- # -*- coding: utf-8 -*-
- # @Time : 2023/02/18 02:07 p.m.
- # @Author : JianingWang
- # @File : parameter_freeze.py
-
- import torch
-
-
- """
- This is used for parameter freezing and unfreezing, which can be viewed as a parameter-efficient setting.
- """
- class ParameterFreeze():
-     # freeze all parameters
-     def freeze_lm(self, model: torch.nn.Module):
-         for name, param in model.named_parameters():
-             param.requires_grad = False
-         return model
-
-     # freeze all parameters except the cls / mlm head
-     def freeze_lm_encoder(self, model: torch.nn.Module):
-         for name, param in model.named_parameters():
-             if "lm_head" in name or ("cls" in name):
-                 print(name)
-                 continue
-             param.requires_grad = False
-         return model
-
-     # freeze all parameters except biases
-     def freeze_lm_finetune_bias(self, model: torch.nn.Module):
-         for name, param in model.named_parameters():
-             if "bias" in name:
-                 print(name)
-                 continue
-             param.requires_grad = False
-         return model
-
-     # freeze everything except the component that the user specifies
-     def freeze_lm_component(self, model: torch.nn.Module, component: str):
-         if "attention" in component:
-             for name, param in model.named_parameters():
-                 if "attention" in name:
-                     if "output" in component:
-                         if "output" in name:
-                             continue
-                     else:
-                         continue
-                 param.requires_grad = False
-             model = self.unfreeze_classification_head(model)
-         elif "feedforward" in component:
-             for name, param in model.named_parameters():
-                 if "dense" in name and "attention" not in name:
-                     if "output" in component:
-                         if "output" in name:
-                             continue
-                     else:
-                         if "intermediate" in component:
-                             if "intermediate" in name:
-                                 continue
-                 param.requires_grad = False
-             model = self.unfreeze_classification_head(model)
-         elif component == "adapter":
-             for name, param in model.named_parameters():
-                 if "adapter" in name:
-                     continue
-
-                 param.requires_grad = False
-             model = self.unfreeze_classification_head(model)
-         elif "embedding" in component:
-             for name, param in model.named_parameters():
-                 if "embedding" in name:
-                     continue
-
-                 param.requires_grad = False
-             model = self.unfreeze_classification_head(model)
-         elif "bias" in component:
-             for name, param in model.named_parameters():
-                 if "bias" in name:
-                     continue
-                 param.requires_grad = False
-             model = self.unfreeze_classification_head(model)
-         elif "head" in component:
-             for name, param in model.named_parameters():
-                 param.requires_grad = False
-             model = self.unfreeze_classification_head(model)
-
-         elif "prompt_emb" in component:
-             for name, param in model.named_parameters():
-                 if "prompt_emb" in name:
-                     continue
-                 param.requires_grad = False
-         return model
-
-     # unfreeze the cls head
-     def unfreeze_classification_head(self, model: torch.nn.Module):
-         for name, param in model.named_parameters():
-             if "lm_head" in name or ("cls" in name) or ("classifier" in name):
-                 param.requires_grad = True
-         return model
-
-     # keep only the top-k layers' output dense parameters (and the cls head) trainable
-     def freeze_lm_k_layers(self, model: torch.nn.Module, k):
-         keep_layers = []
-         update_parameters = []
-         for i in range(k):
-             keep_layers.append("layer." + str(23 - i))
-
-         for name, param in model.named_parameters():
-             update = False
-             for layer_num in keep_layers:
-                 if layer_num in name:
-                     if "dense" in name and "attention" not in name:
-                         if "output" in name:
-                             print(name)
-                             update_parameters.append(name)
-                             update = True
-
-             if not update:
-                 param.requires_grad = False
-         model = self.unfreeze_classification_head(model)
-         return model
-
-     def unfreeze_lm(self, model: torch.nn.Module):
-         for param in model.parameters():
-             param.requires_grad = True
-         return model
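For reference, a usage sketch of BitFit-style bias-only tuning with these helpers, assuming a standard Hugging Face checkpoint such as `bert-base-uncased`:

```python
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
model = ParameterFreeze().freeze_lm_finetune_bias(model)  # only bias terms keep requires_grad=True

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {trainable} / {total}")
```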
 
models/tools/model_utils/uncertainty.py DELETED
@@ -1,137 +0,0 @@
- # -*- coding: utf-8 -*-
- # @Time : 2023/04/18 08:11 p.m.
- # @Author : JianingWang
- # @File : uncertainty.py
-
- from sklearn.utils import shuffle
- import logging
- import numpy as np
- import os
- import random
-
-
- logger = logging.getLogger(__name__)
-
-
- def get_BALD_acquisition(y_T):
-
-     expected_entropy = -np.mean(np.sum(y_T * np.log(y_T + 1e-10), axis=-1), axis=0)
-     expected_p = np.mean(y_T, axis=0)
-     entropy_expected_p = -np.sum(expected_p * np.log(expected_p + 1e-10), axis=-1)
-     return (entropy_expected_p - expected_entropy)
-
-
- def sample_by_bald_difficulty(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
-
-     logger.info("Sampling by difficulty BALD acquisition function")
-     BALD_acq = get_BALD_acquisition(y_T)
-     p_norm = np.maximum(np.zeros(len(BALD_acq)), BALD_acq)
-     p_norm = p_norm / np.sum(p_norm)
-     indices = np.random.choice(len(X['input_ids']), num_samples, p=p_norm, replace=False)
-     X_s = {"input_ids": X["input_ids"][indices], "token_type_ids": X["token_type_ids"][indices], "attention_mask": X["attention_mask"][indices]}
-     y_s = y[indices]
-     w_s = y_var[indices][:, 0]
-     return X_s, y_s, w_s
-
-
- def sample_by_bald_easiness(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
-
-     logger.info("Sampling by easy BALD acquisition function")
-     BALD_acq = get_BALD_acquisition(y_T)
-     p_norm = np.maximum(np.zeros(len(BALD_acq)), (1. - BALD_acq) / np.sum(1. - BALD_acq))
-     p_norm = p_norm / np.sum(p_norm)
-     logger.info(p_norm[:10])
-     indices = np.random.choice(len(X['input_ids']), num_samples, p=p_norm, replace=False)
-     X_s = {"input_ids": X["input_ids"][indices], "token_type_ids": X["token_type_ids"][indices], "attention_mask": X["attention_mask"][indices]}
-     y_s = y[indices]
-     w_s = y_var[indices][:, 0]
-     return X_s, y_s, w_s
-
-
- def sample_by_bald_class_easiness(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
-
-     logger.info("Sampling by easy BALD acquisition function per class")
-     BALD_acq = get_BALD_acquisition(y_T)
-     BALD_acq = (1. - BALD_acq) / np.sum(1. - BALD_acq)
-     logger.info(BALD_acq)
-     samples_per_class = num_samples // num_classes
-     X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, X_s_mask_pos, y_s, w_s = [], [], [], [], [], []
-
-     for label in range(num_classes):
-         # X_input_ids, X_token_type_ids, X_attention_mask = np.array(X['input_ids'])[y == label], np.array(X['token_type_ids'])[y == label], np.array(X['attention_mask'])[y == label]
-         X_input_ids, X_attention_mask = np.array(X['input_ids'])[y == label], np.array(X['attention_mask'])[y == label]
-         if "token_type_ids" in X.features:
-             X_token_type_ids = np.array(X['token_type_ids'])[y == label]
-         if "mask_pos" in X.features:
-             X_mask_pos = np.array(X['mask_pos'])[y == label]
-         y_ = y[y == label]
-         y_var_ = y_var[y == label]
-         # p = y_mean[y == label]
-         p_norm = BALD_acq[y == label]
-         p_norm = np.maximum(np.zeros(len(p_norm)), p_norm)
-         p_norm = p_norm / np.sum(p_norm)
-         if len(X_input_ids) < samples_per_class:
-             logger.info("Sampling with replacement.")
-             replace = True
-         else:
-             replace = False
-         if len(X_input_ids) == 0:  # add by wjn
-             continue
-         indices = np.random.choice(len(X_input_ids), samples_per_class, p=p_norm, replace=replace)
-         X_s_input_ids.extend(X_input_ids[indices])
-         # X_s_token_type_ids.extend(X_token_type_ids[indices])
-         X_s_attention_mask.extend(X_attention_mask[indices])
-         if "token_type_ids" in X.features:
-             X_s_token_type_ids.extend(X_token_type_ids[indices])
-         if "mask_pos" in X.features:
-             X_s_mask_pos.extend(X_mask_pos[indices])
-         y_s.extend(y_[indices])
-         w_s.extend(y_var_[indices][:, 0])
-
-     # X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s)
-     if "token_type_ids" in X.features and "mask_pos" not in X.features:
-         X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s)
-     elif "token_type_ids" not in X.features and "mask_pos" in X.features:
-         X_s_input_ids, X_s_mask_pos, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_mask_pos, X_s_attention_mask, y_s, w_s)
-     elif "token_type_ids" in X.features and "mask_pos" in X.features:
-         X_s_input_ids, X_s_token_type_ids, X_s_mask_pos, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_token_type_ids, X_s_mask_pos, X_s_attention_mask, y_s, w_s)
-     else:
-         X_s_input_ids, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_attention_mask, y_s, w_s)
-
-     pseudo_labeled_input = {
-         'input_ids': np.array(X_s_input_ids),
-         'attention_mask': np.array(X_s_attention_mask)
-     }
-     if "token_type_ids" in X.features:
-         pseudo_labeled_input['token_type_ids'] = np.array(X_s_token_type_ids)
-     if "mask_pos" in X.features:
-         pseudo_labeled_input['mask_pos'] = np.array(X_s_mask_pos)
-     return pseudo_labeled_input, np.array(y_s), np.array(w_s)
-
-
- def sample_by_bald_class_difficulty(tokenizer, X, y_mean, y_var, y, num_samples, num_classes, y_T):
-
-     logger.info("Sampling by difficulty BALD acquisition function per class")
-     BALD_acq = get_BALD_acquisition(y_T)
-     samples_per_class = num_samples // num_classes
-     X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s = [], [], [], [], []
-     for label in range(num_classes):
-         X_input_ids, X_token_type_ids, X_attention_mask = X['input_ids'][y == label], X['token_type_ids'][y == label], X['attention_mask'][y == label]
-         y_ = y[y == label]
-         y_var_ = y_var[y == label]
-         p_norm = BALD_acq[y == label]
-         p_norm = np.maximum(np.zeros(len(p_norm)), p_norm)
-         p_norm = p_norm / np.sum(p_norm)
-         if len(X_input_ids) < samples_per_class:
-             replace = True
-             logger.info("Sampling with replacement.")
-         else:
-             replace = False
-         indices = np.random.choice(len(X_input_ids), samples_per_class, p=p_norm, replace=replace)
-         X_s_input_ids.extend(X_input_ids[indices])
-         X_s_token_type_ids.extend(X_token_type_ids[indices])
-         X_s_attention_mask.extend(X_attention_mask[indices])
-         y_s.extend(y_[indices])
-         w_s.extend(y_var_[indices][:, 0])
-     X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s = shuffle(X_s_input_ids, X_s_token_type_ids, X_s_attention_mask, y_s, w_s)
-     return {'input_ids': np.array(X_s_input_ids), 'token_type_ids': np.array(X_s_token_type_ids), 'attention_mask': np.array(X_s_attention_mask)}, np.array(y_s), np.array(w_s)
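For context, the BALD acquisition above is the mutual information between predictions and model uncertainty: the entropy of the mean prediction minus the mean per-pass entropy. A small sketch with assumed Monte-Carlo dropout outputs `y_T` of shape `[T, N, num_classes]`:

```python
import numpy as np

rng = np.random.default_rng(0)
y_T = rng.dirichlet(np.ones(3), size=(5, 4))  # 5 stochastic passes, 4 examples, 3 classes
bald = get_BALD_acquisition(y_T)
print(bald.shape)  # (4,) -- one score per example; larger values mean more disagreement across passes
```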
 
models/tools/processing_utils/common.py DELETED
@@ -1,38 +0,0 @@
- # -*- coding: utf-8 -*-
- # @Time : 2021/12/2 5:41 p.m.
- # @Author : JianingWang
- # @File : common.py
-
-
- def is_chinese_char(cp):
-     """Checks whether CP is the codepoint of a CJK character."""
-     # This defines a "chinese character" as anything in the CJK Unicode block:
-     # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
-     #
-     # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
-     # despite its name. The modern Korean Hangul alphabet is a different block,
-     # as is Japanese Hiragana and Katakana. Those alphabets are used to write
-     # space-separated words, so they are not treated specially and are handled
-     # like all of the other languages.
-     if (
-         (0x4E00 <= cp <= 0x9FFF)
-         or (0x3400 <= cp <= 0x4DBF)
-         or (0x20000 <= cp <= 0x2A6DF)
-         or (0x2A700 <= cp <= 0x2B73F)
-         or (0x2B740 <= cp <= 0x2B81F)
-         or (0x2B820 <= cp <= 0x2CEAF)
-         or (0xF900 <= cp <= 0xFAFF)
-         or (0x2F800 <= cp <= 0x2FA1F)
-     ):
-         return True
-
-     return False
-
-
- def is_chinese(word: str):
-     # word like "180" or "身高" or "神"
-     for char in word:
-         char = ord(char)
-         if not is_chinese_char(char):
-             return 0
-     return 1
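A quick usage sketch of the deleted helpers:

```python
print(is_chinese("身高"))          # 1 -- every character falls in a CJK block
print(is_chinese("180"))           # 0 -- digits are not CJK characters
print(is_chinese_char(ord("神")))  # True
```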
 
models/tools/processing_utils/sampler.py DELETED
@@ -1,26 +0,0 @@
- # -*- coding: utf-8 -*-
- # @Time : 2021/12/2 5:41 p.m.
- # @Author : JianingWang
- # @File : sampler.py
-
- import numpy as np
- from typing import Optional
-
- """
- Random sampling of k examples for each label.
- """
- def random_sampling(raw_datasets, num_examples_per_label: Optional[int] = 16):
-     label_list = raw_datasets["label"]  # [0, 1, 0, 0, ...]
-     label_dict = dict()
-     # record the example indices of each label
-     for ei, label in enumerate(label_list):
-         if label not in label_dict.keys():
-             label_dict[label] = list()
-         label_dict[label].append(ei)
-     # randomly sample k examples of each class
-     few_example_ids = list()
-     for label, eid_list in label_dict.items():
-         idxs = np.random.choice(len(eid_list), size=num_examples_per_label, replace=False)
-         selected_eids = [eid_list[i] for i in idxs]
-         few_example_ids.extend(selected_eids)
-     return few_example_ids
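A usage sketch with a hypothetical dataset dict that exposes a "label" column, as the function expects:

```python
import numpy as np

np.random.seed(42)
raw_datasets = {"label": [0, 1, 0, 1, 0, 1, 0, 1]}  # toy labels
ids = random_sampling(raw_datasets, num_examples_per_label=2)
print(ids)  # 2 sampled example indices per class, 4 in total
```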
 
models/tools/processing_utils/tokenizer/JiebaTokenizer.py DELETED
@@ -1,24 +0,0 @@
- # -*- coding: utf-8 -*-
- # @Time : 2021/12/8 12:07 a.m.
- # @Author : JianingWang
- # @File : JiebaTokenizer
-
- import jieba
- from transformers import BertTokenizer
-
-
- class JiebaTokenizer(BertTokenizer):
-     def __init__(
-         self, pre_tokenizer=lambda x: jieba.cut(x, HMM=False), *args, **kwargs
-     ):
-         super().__init__(*args, **kwargs)
-         self.pre_tokenizer = pre_tokenizer
-
-     def _tokenize(self, text, *arg, **kwargs):
-         split_tokens = []
-         for text in self.pre_tokenizer(text):
-             if text in self.vocab:
-                 split_tokens.append(text)
-             else:
-                 split_tokens.extend(super()._tokenize(text))
-         return split_tokens
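A usage sketch, assuming a Chinese BERT vocabulary such as `bert-base-chinese`: jieba-segmented words that exist in the vocab are kept whole, everything else falls back to BertTokenizer's WordPiece splitting.

```python
tokenizer = JiebaTokenizer.from_pretrained("bert-base-chinese")
print(tokenizer.tokenize("今天天气很好"))  # jieba words are kept whole only when they are in the vocab
```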
 
models/tools/processing_utils/tokenizer/__init__.py DELETED
@@ -1,4 +0,0 @@
- # -*- coding: utf-8 -*-
- # @Time : 2021/12/8 12:07 a.m.
- # @Author : JianingWang
- # @File : __init__.py
 
models/tools/processing_utils/tokenizer/tokenizer_utils.py DELETED
@@ -1,19 +0,0 @@
- from transformers import AutoTokenizer
-
- """
- Obtain the special-token mapping for a given tokenizer.
- """
- def get_special_token_mapping(tokenizer: AutoTokenizer):
-     if "t5" in type(tokenizer).__name__.lower():
-         special_token_mapping = {
-             "cls": 3, "mask": 32099, "sep": tokenizer.eos_token_id,
-             "sep+": tokenizer.eos_token_id,
-             "pseudo_token": tokenizer.unk_token_id
-         }
-     else:
-         special_token_mapping = {
-             "cls": tokenizer.cls_token_id, "mask": tokenizer.mask_token_id, "sep": tokenizer.sep_token_id,
-             "sep+": tokenizer.sep_token_id,
-             "pseudo_token": tokenizer.unk_token_id
-         }
-     return special_token_mapping
 
models/tools/runner_utils/__init__.py DELETED
File without changes
models/tools/runner_utils/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (140 Bytes)
 
models/tools/runner_utils/__pycache__/log_util.cpython-38.pyc DELETED
Binary file (969 Bytes)
 
models/tools/runner_utils/conifg_extensive.py DELETED
@@ -1,15 +0,0 @@
- from transformers import AutoConfig
- from config import ModelArguments
-
-
- # add external config.
- def config_extensive(hf_config: AutoConfig, model_config: ModelArguments):
-     hf_config.use_prompt_for_cls = model_config.use_prompt_for_cls
-     hf_config.use_freezing = model_config.use_freezing
-     hf_config.adapter_choice = model_config.adapter_choice
-     hf_config.adapter_dim = model_config.adapter_dim
-     hf_config.pre_seq_len = model_config.pre_seq_len
-     hf_config.prefix_projection = model_config.prefix_projection
-     hf_config.prefix_hidden_size = model_config.prefix_hidden_size
-     hf_config.hidden_dropout_prob = model_config.hidden_dropout_prob
-     return hf_config
 
models/tools/runner_utils/log_util.py DELETED
@@ -1,30 +0,0 @@
- import sys
- import logging
- import datasets
- import transformers
-
-
- def init_logger(log_file, log_level, dist_rank):
-     datasets.utils.logging.set_verbosity(log_level)
-     transformers.utils.logging.set_verbosity(log_level)
-     transformers.utils.logging.enable_default_handler()
-     transformers.utils.logging.enable_explicit_format()
-     datasets.utils.logging.disable_propagation()
-     # transformers.utils.logging.enable_propagation()
-
-     logger = logging.getLogger("")
-     log_format = logging.Formatter(fmt="[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-     logger.setLevel(log_level)
-     console_handler = logging.StreamHandler(sys.stderr)
-     console_handler.setFormatter(log_format)
-     logger.addHandler(console_handler)
-     # transformer_logger = logging.getLogger("transformers")
-     # transformer_logger.handlers = []
-     # transformer_logger.propagate = True
-
-     if dist_rank in [-1, 0]:
-         file_handler = logging.FileHandler(log_file, mode="a")
-         file_handler.setLevel(log_level)
-         file_handler.setFormatter(log_format)
-         logger.addHandler(file_handler)
-         logging.getLogger("transformers").addHandler(file_handler)
 
models/tools/runner_utils/retrying.py DELETED
@@ -1,288 +0,0 @@
- # -*- coding: utf-8 -*-
- # @Time : 2021/12/24 4:05 p.m.
- # @Author : JianingWang
- # @File : retrying.py
-
- import random
- import six
- import sys
- import time
- import traceback
-
-
- MAX_WAIT = 1073741823
-
-
- def _retry_if_exception_of_type(retryable_types):
-     def _retry_if_exception_these_types(exception):
-         return isinstance(exception, retryable_types)
-     return _retry_if_exception_these_types
-
-
- def retry(*dargs, **dkw):
-     """
-     Decorator function that instantiates the Retrying object
-     @param *dargs: positional arguments passed to the Retrying object
-     @param **dkw: keyword arguments passed to the Retrying object
-     """
-     # support both @retry and @retry() as valid syntax
-     if len(dargs) == 1 and callable(dargs[0]):
-         def wrap_simple(f):
-
-             @six.wraps(f)
-             def wrapped_f(*args, **kw):
-                 return Retrying().call(f, *args, **kw)
-
-             return wrapped_f
-
-         return wrap_simple(dargs[0])
-
-     else:
-         def wrap(f):
-
-             @six.wraps(f)
-             def wrapped_f(*args, **kw):
-                 return Retrying(*dargs, **dkw).call(f, *args, **kw)
-
-             return wrapped_f
-
-         return wrap
-
-
- class Retrying(object):
-
-     def __init__(self,
-                  stop=None, wait=None,
-                  stop_max_attempt_number=None,
-                  stop_max_delay=None,
-                  wait_fixed=None,
-                  wait_random_min=None, wait_random_max=None,
-                  wait_incrementing_start=None, wait_incrementing_increment=None,
-                  wait_incrementing_max=None,
-                  wait_exponential_multiplier=None, wait_exponential_max=None,
-                  retry_on_exception=None,
-                  retry_on_result=None,
-                  wrap_exception=False,
-                  stop_func=None,
-                  wait_func=None,
-                  wait_jitter_max=None,
-                  before_attempts=None,
-                  after_attempts=None,
-                  skip_raise=False):
-
-         self._stop_max_attempt_number = 5 if stop_max_attempt_number is None else stop_max_attempt_number
-         self._stop_max_delay = 100 if stop_max_delay is None else stop_max_delay
-         self._wait_fixed = 1000 if wait_fixed is None else wait_fixed
-         self._wait_random_min = 0 if wait_random_min is None else wait_random_min
-         self._wait_random_max = 1000 if wait_random_max is None else wait_random_max
-         self._wait_incrementing_start = 0 if wait_incrementing_start is None else wait_incrementing_start
-         self._wait_incrementing_increment = 100 if wait_incrementing_increment is None else wait_incrementing_increment
-         self._wait_exponential_multiplier = 1 if wait_exponential_multiplier is None else wait_exponential_multiplier
-         self._wait_exponential_max = MAX_WAIT if wait_exponential_max is None else wait_exponential_max
-         self._wait_incrementing_max = MAX_WAIT if wait_incrementing_max is None else wait_incrementing_max
-         self._wait_jitter_max = 0 if wait_jitter_max is None else wait_jitter_max
-         self._before_attempts = before_attempts
-         self._after_attempts = after_attempts
-         self._skip_raise = skip_raise
-
-         # stop behavior
-         stop_funcs = []
-         if stop_max_attempt_number is not None:
-             stop_funcs.append(self.stop_after_attempt)
-
-         if stop_max_delay is not None:
-             stop_funcs.append(self.stop_after_delay)
-
-         if stop_func is not None:
-             self.stop = stop_func
-
-         elif stop is None:
-             self.stop = lambda attempts, delay: any(f(attempts, delay) for f in stop_funcs)
-
-         else:
-             self.stop = getattr(self, stop)
-
-         # wait behavior
-         wait_funcs = [lambda *args, **kwargs: 0]
-         if wait_fixed is not None:
-             wait_funcs.append(self.fixed_sleep)
-
-         if wait_random_min is not None or wait_random_max is not None:
-             wait_funcs.append(self.random_sleep)
-
-         if wait_incrementing_start is not None or wait_incrementing_increment is not None:
-             wait_funcs.append(self.incrementing_sleep)
-
-         if wait_exponential_multiplier is not None or wait_exponential_max is not None:
-             wait_funcs.append(self.exponential_sleep)
-
-         if wait_func is not None:
-             self.wait = wait_func
-
-         elif wait is None:
-             self.wait = lambda attempts, delay: max(f(attempts, delay) for f in wait_funcs)
-
-         else:
-             self.wait = getattr(self, wait)
-
-         # retry on exception filter
-         if retry_on_exception is None:
-             self._retry_on_exception = self.always_reject
-         else:
-             # this allows for providing a tuple of exception types that
-             # should be allowed to retry on, and avoids having to create
-             # a callback that does the same thing
-             if isinstance(retry_on_exception, (tuple)):
-                 retry_on_exception = _retry_if_exception_of_type(
-                     retry_on_exception)
-             self._retry_on_exception = retry_on_exception
-
-         # retry on result filter
-         if retry_on_result is None:
-             self._retry_on_result = self.never_reject
-         else:
-             self._retry_on_result = retry_on_result
-
-         self._wrap_exception = wrap_exception
-
-     def stop_after_attempt(self, previous_attempt_number, delay_since_first_attempt_ms):
-         """Stop after the previous attempt >= stop_max_attempt_number."""
-         return previous_attempt_number >= self._stop_max_attempt_number
-
-     def stop_after_delay(self, previous_attempt_number, delay_since_first_attempt_ms):
-         """Stop after the time from the first attempt >= stop_max_delay."""
-         return delay_since_first_attempt_ms >= self._stop_max_delay
-
-     @staticmethod
-     def no_sleep(previous_attempt_number, delay_since_first_attempt_ms):
-         """Don't sleep at all before retrying."""
-         return 0
-
-     def fixed_sleep(self, previous_attempt_number, delay_since_first_attempt_ms):
-         """Sleep a fixed amount of time between each retry."""
-         return self._wait_fixed
-
-     def random_sleep(self, previous_attempt_number, delay_since_first_attempt_ms):
-         """Sleep a random amount of time between wait_random_min and wait_random_max."""
-         return random.randint(self._wait_random_min, self._wait_random_max)
-
-     def incrementing_sleep(self, previous_attempt_number, delay_since_first_attempt_ms):
-         """
-         Sleep an incremental amount of time after each attempt, starting at
-         wait_incrementing_start and incrementing by wait_incrementing_increment
-         """
-         result = self._wait_incrementing_start + (self._wait_incrementing_increment * (previous_attempt_number - 1))
-         if result > self._wait_incrementing_max:
-             result = self._wait_incrementing_max
-         if result < 0:
-             result = 0
-         return result
-
-     def exponential_sleep(self, previous_attempt_number, delay_since_first_attempt_ms):
-         exp = 2 ** previous_attempt_number
-         result = self._wait_exponential_multiplier * exp
-         if result > self._wait_exponential_max:
-             result = self._wait_exponential_max
-         if result < 0:
-             result = 0
-         return result
-
-     @staticmethod
-     def never_reject(result):
-         return False
-
-     @staticmethod
-     def always_reject(result):
-         return True
-
-     def should_reject(self, attempt):
-         reject = False
-         if attempt.has_exception:
-             reject |= self._retry_on_exception(attempt.value[1])
-         else:
-             reject |= self._retry_on_result(attempt.value)
-
-         return reject
-
-     def call(self, fn, *args, **kwargs):
-         start_time = int(round(time.time() * 1000))
-         attempt_number = 1
-         while True:
-             if self._before_attempts:
-                 self._before_attempts(attempt_number)
-
-             try:
-                 attempt = Attempt(fn(*args, **kwargs), attempt_number, False)
-             except:
-                 tb = sys.exc_info()
-                 attempt = Attempt(tb, attempt_number, True)
-
-             if not self.should_reject(attempt):
-                 return attempt.get(self._wrap_exception)
-
-             if self._after_attempts:
-                 self._after_attempts(attempt_number)
-
-             delay_since_first_attempt_ms = int(round(time.time() * 1000)) - start_time
-             if self.stop(attempt_number, delay_since_first_attempt_ms):
-                 if not self._wrap_exception and attempt.has_exception:
-                     # get() on an attempt with an exception should cause it to be raised, but raise just in case
-                     if not self._skip_raise:
-                         raise attempt.get()
-                     else:
-                         break
-                 else:
-                     raise RetryError(attempt)
-             else:
-                 sleep = self.wait(attempt_number, delay_since_first_attempt_ms)
-                 if self._wait_jitter_max:
-                     jitter = random.random() * self._wait_jitter_max
-                     sleep = sleep + max(0, jitter)
-                 time.sleep(sleep / 1000.0)
-
-             attempt_number += 1
-
-
- class Attempt(object):
-     """
-     An Attempt encapsulates a call to a target function that may end as a
-     normal return value from the function or an Exception depending on what
-     occurred during the execution.
-     """
-
-     def __init__(self, value, attempt_number, has_exception):
-         self.value = value
-         self.attempt_number = attempt_number
-         self.has_exception = has_exception
-
-     def get(self, wrap_exception=False):
-         """
-         Return the return value of this Attempt instance or raise an Exception.
-         If wrap_exception is true, this Attempt is wrapped inside of a
-         RetryError before being raised.
-         """
-         if self.has_exception:
-             if wrap_exception:
-                 raise RetryError(self)
-             else:
-                 six.reraise(self.value[0], self.value[1], self.value[2])
-         else:
-             return self.value
-
-     def __repr__(self):
-         if self.has_exception:
-             return "Attempts: {0}, Error:\n{1}".format(self.attempt_number, "".join(traceback.format_tb(self.value[2])))
-         else:
-             return "Attempts: {0}, Value: {1}".format(self.attempt_number, self.value)
-
-
- class RetryError(Exception):
-     """
-     A RetryError encapsulates the last Attempt instance right before giving up.
-     """
-
-     def __init__(self, last_attempt):
-         self.last_attempt = last_attempt
-
-     def __str__(self):
-         return "RetryError[{0}]".format(self.last_attempt)
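A usage sketch of the decorator; the stop and wait values below are illustrative, in milliseconds as the implementation above expects:

```python
import random

@retry(stop_max_attempt_number=3, wait_fixed=200)
def flaky_call():
    # hypothetical unreliable operation; retried up to 3 times with a 200 ms fixed wait
    if random.random() < 0.5:
        raise RuntimeError("transient failure")
    return "ok"

print(flaky_call())
```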
 
models/tools/runner_utils/set_seed.py DELETED
@@ -1,21 +0,0 @@
- import torch
- import random
- import numpy as np
-
- from transformers.utils import (
-     is_tf_available,
-     is_torch_available,
- )
-
- def set_seed(seed_value: int):
-     """
-     Helper function for reproducible behavior: set the seed in `random`, `numpy` and `torch` (if installed).
-
-     Args:
-         seed_value (`int`): The seed to set.
-     """
-     random.seed(seed_value)
-     np.random.seed(seed_value)
-     if is_torch_available():
-         torch.manual_seed(seed_value)
-         torch.cuda.manual_seed_all(seed_value)
 
models/tools/runner_utils/timecost.py DELETED
@@ -1,20 +0,0 @@
- # -*- coding: utf-8 -*-
- # @Time : 2022/3/11 3:06 p.m.
- # @Author : JianingWang
- # @File : time
-
- import time
- import logging
-
- logger = logging.getLogger(__name__)
-
-
- def timecost(method):
-     def timed(*args, **kw):
-         ts = time.time()
-         result = method(*args, **kw)
-         te = time.time()
-         logger.info("%r %2.2f ms" % (method.__name__, (te - ts) * 1000))
-         return result
-
-     return timed
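A usage sketch of the deleted decorator, assuming logging is configured so INFO records are visible:

```python
import logging
import time

logging.basicConfig(level=logging.INFO)

@timecost
def slow_add(a, b):
    time.sleep(0.05)
    return a + b

print(slow_add(1, 2))  # logs roughly: 'slow_add' 50.xx ms, then prints 3
```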