Spaces:
Sleeping
Sleeping
DeepLearning101
committed on
Commit
•
08f4077
1
Parent(s):
62c36b1
第一次測試佈署更新
Browse files- app.py +258 -0
- applications/information_extraction/HugIE/api_test.py +234 -0
- models/__init__.py +292 -0
- requirements.txt +11 -0
- wjn1996-hugnlp-hugie-large-zh/config.json +38 -0
- wjn1996-hugnlp-hugie-large-zh/gitattributes.txt +34 -0
- wjn1996-hugnlp-hugie-large-zh/special_tokens_map.json +7 -0
- wjn1996-hugnlp-hugie-large-zh/tokenizer.json +0 -0
- wjn1996-hugnlp-hugie-large-zh/tokenizer_config.json +14 -0
- wjn1996-hugnlp-hugie-large-zh/vocab.txt +0 -0
app.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2023/05/30
|
3 |
+
# @Author : TonTon H.-D. Huang Ph.D.
|
4 |
+
# @Web :http://TWMAN.ORG
|
5 |
+
# @EMail :TonTon@TWMAN.ORG
|
6 |
+
# @File : HugIE.py
|
7 |
+
# @Description :毋需重新訓練的醫療診斷書醫囑文字分析
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
import json, re
|
11 |
+
from applications.information_extraction.HugIE.api_test import HugIEAPI
|
12 |
+
from dateutil import parser
|
13 |
+
from datetime import datetime
|
14 |
+
|
15 |
+
model_type = "bert"
|
16 |
+
hugie_model_name_or_path = "./wjn1996-hugnlp-hugie-large-zh/" #如果不能連網,請自行下載並設定路徑
|
17 |
+
hugie = HugIEAPI(model_type, hugie_model_name_or_path)
|
18 |
+
|
19 |
+
def convert_to_ROC_date(date):  # converts the year/month/day portion only
    """Normalise the first date found in *date* to a compact ROC string.

    Accepts both ROC (3-digit year, e.g. ``112年02月13日``) and Western
    (4-digit year, e.g. ``2023-05-02``) dates with ``-``, ``:``, ``:``,
    ``/``, ``.`` or CJK separators.  Western years are converted to the
    ROC calendar by subtracting 1911.

    Returns the date as ``yyymmdd`` (zero-padded, 7 chars) when a date is
    found, otherwise returns *date* unchanged.
    """
    # Capture year/month/day directly; same separator classes as before.
    date_regex = r'(\d{3,4})[-::/.年](\d{1,2})[-::/.月](\d{1,2})[日]?'

    date_match = re.search(date_regex, date)
    if not date_match:
        return date

    year = int(date_match.group(1))
    if year > 1911:
        # 4-digit Western year (2023, 1999, ...) -> ROC calendar.
        year -= 1911
    month = int(date_match.group(2))
    day = int(date_match.group(3))

    return f"{year:03d}{month:02d}{day:02d}"
40 |
+
|
41 |
+
def convert_to_ROC_time(time):
    """Normalise the first time-of-day found in *time* to ``HHMMSS``.

    Only the relation2 fields (急診 start/end times) carry a time of day,
    so this handles ``H時M分``, ``H時M分S秒`` and the equivalent
    ``-``/``:``/``:``/``/``/``.``-separated forms.  Seconds default to 0
    when absent.

    Returns *time* unchanged when no time pattern is found (previously
    this fell through and returned ``None``, which then leaked into the
    JSON output).
    """
    # Capture hour/minute/optional-second directly instead of re-parsing
    # the matched text with strptime, which crashed (uncaught ValueError)
    # on matches ending in a bare separator such as "12:15:".
    time_regex = r'(\d{1,2})[-::/.時](\d{1,2})[-::/.分](?:(\d{1,2})[秒])?'

    time_match = re.search(time_regex, time)
    if not time_match:
        return time

    hour = int(time_match.group(1))
    minute = int(time_match.group(2))
    second = int(time_match.group(3)) if time_match.group(3) else 0

    return f"{hour:02d}{minute:02d}{second:02d}"
55 |
+
|
56 |
+
def extract_information(text):
    """Run HugIE extraction over a medical-certificate narrative.

    For every event type (admission, ICU, surgery, ...) whose trigger
    keyword occurs in *text*, query the global ``hugie`` model for the
    configured relation(s), normalise extracted dates/times to compact
    ROC strings, merge results that share an entity label, and return
    the final list serialised as a JSON string.
    """
    # Trigger keywords per event type; extend freely — no retraining needed.
    keywords = {
        'Hospital1': ['入院', '住入本院', '普通病房', '住院', '轉入一般病房', '入住本院'],  # admission, regular ward
        'Hospital2': ['出院', '離院'],  # discharge, regular ward
        'Burn1': ['燒燙傷'],  # burn-unit ward
        'Burn2': ['燒燙傷'],  # burn-unit ward
        'ICU1': ['加護病房', '住院加護病房'],
        'ICU2': ['轉普通病房', '轉入普通病房', '轉至普通病房', '轉回一般病房', '轉至兒科一般病房'],
        'exclude_Clinic': ['門診追蹤', '門診複查', '門診持續追蹤', '急診求治', '繼續追蹤', '急診就診'],
        'Clinic': ['牙科', '來院門診', '門診就診', '看診', '回診', '門診回診', '婦科就診', '門診治療', '來院就診', '本院診療', "本院門診", "經門診", "門診就醫", "由門診", "接受門診", "至診就診", "至門診複診"],
        'Operation1': ['手術', '切除術', '置放術', '切片術', '幹細胞'],
        'Operation2': ['左側乳房部分切除併前哨淋巴清除手術', '手術', '切除術', '置放術', '切片術', '幹細胞'],
        'Emergency1': ['急診'],
        'Emergency2': ['住入加護病房'],
        'Chemotherapy': ['化學治療', '化療', '靜脈注射免疫藥物及標靶藥物治療'],
        'Cancer': ['罹癌'],
        'Radiation': ['放射線', '放射']
    }

    # Output entity label plus the relation question(s) asked per event type.
    relations = {
        'Hospital1': {'entity': '住院A', 'relation1': '開始日期'},
        'Hospital2': {'entity': '住院A', 'relation1': '結束日期'},
        'Burn1': {'entity': '燒燙傷病房B', 'relation1': '開始日期'},
        'Burn2': {'entity': '燒燙傷病房B', 'relation1': '結束日期'},
        'ICU1': {'entity': '加護病房C', 'relation1': '開始日期'},
        'ICU2': {'entity': '加護病房C', 'relation1': '結束日期'},
        'exclude_Clinic': {'entity': None},
        'Clinic': {'entity': '門診D', 'relation1': '日期'},
        'Operation1': {'entity': '手術F', 'relation1': '日期'},
        'Operation2': {'entity': '手術F', 'relation1': '手術項目'},
        'Emergency1': {'entity': '急診G', 'relation1': '開始日期', 'relation2': '開始時間'},
        'Emergency2': {'entity': '急診G', 'relation1': '結束日期', 'relation2': '終止時間'},
        'Chemotherapy': {'entity': '癌症化療H', 'relation1': '起訖日'},
        'Cancer': {'entity': '罹癌I', 'relation1': '起訖日'},
        'Radiation': {'entity': '癌症放射線J', 'relation1': '起訖日'}
    }

    # A:住院、B:燒燙傷、C:加護病房、D:門診、F:手術、G:急診、H:癌症化療、I:罹癌、J:癌症放射線

    # ROC date yyymmdd (3-digit year starting with 1) with an optional
    # hhmm/hhmmss tail.  The original pattern used the invalid quantifier
    # {4-6} and a 5-digit year, so it could never match the 7-char slice
    # it was applied to; this version actually matches converted dates.
    roc_date_pattern = r"1\d{2}(?:0[1-9]|1[0-2])(?:0[1-9]|[1-2]\d|3[01])(?:\d{4,6})?$"

    def _bucket(relations_out, label):
        # Fetch (or lazily create) the prediction bucket for a relation label.
        return relations_out.setdefault(label, {'relation': label, 'predictions': []})

    def _append_start(relations_out, label, value):
        # De-duplicate on the 7-char yyymmdd prefix first, then on the full value.
        bucket = _bucket(relations_out, label)
        if value[:7] not in bucket['predictions']:
            bucket['predictions'].append(value[:7])
        elif value not in bucket['predictions']:
            bucket['predictions'].append(value)

    def _append_date(relations_out, label, value):
        # Keep the 7-char ROC date when the prefix looks like one, else the raw value.
        bucket = _bucket(relations_out, label)
        candidate = value[:7] if re.match(roc_date_pattern, value[:7]) else value
        if candidate not in bucket['predictions']:
            bucket['predictions'].append(candidate)

    def _append_plain(relations_out, label, value):
        # Append the untruncated value (used for surgery names).
        bucket = _bucket(relations_out, label)
        if value not in bucket['predictions']:
            bucket['predictions'].append(value)

    results = []

    for entity, keyword_list in keywords.items():
        output = {
            'entity': relations[entity]['entity'],
            'relations': {}
        }

        for keyword in keyword_list:
            if keyword in keywords['exclude_Clinic']:
                continue  # follow-up phrases that must never trigger extraction

            if keyword not in text or entity not in relations:
                continue

            relation1 = relations[entity].get('relation1')
            relation2 = relations[entity].get('relation2')

            if relation1:
                predictions, topk_predictions = hugie.request(text, keyword, relation=relation1)
                if predictions[0]:  # at least one predicted span
                    for prediction in predictions[0]:
                        date_prediction = convert_to_ROC_date(prediction)

                        if relation1 == '開始日期':
                            _append_start(output['relations'], '受理_起始日', date_prediction)
                        elif relation1 == '結束日期':
                            _append_date(output['relations'], '受理_終止日', date_prediction)
                        elif relation1 in ['起訖日', '日期']:
                            # A single-date relation fills both start and end slots.
                            _append_date(output['relations'], '受理_起始日', date_prediction)
                            _append_date(output['relations'], '受理_終止日', date_prediction)
                        elif relation1 == '手術項目':
                            _append_plain(output['relations'], '手術項目', date_prediction)

            if relation2:
                predictions, topk_predictions = hugie.request(text, keyword, relation=relation2)
                if predictions[0]:  # at least one predicted span
                    for prediction in predictions[0]:
                        time_prediction = convert_to_ROC_time(prediction)

                        # NOTE(review): keyed by the raw relation name (not the
                        # display label), and each prediction overwrites the
                        # previous one — preserved from the original behaviour.
                        if relation2 == '開始時間':
                            output['relations'][relation2] = {
                                'relation': '受理_起始日時分秒',
                                'predictions': [time_prediction]
                            }
                        if relation2 == '終止時間':
                            output['relations'][relation2] = {
                                'relation': '受理_終止日時分秒',
                                'predictions': [time_prediction]
                            }

        # Merge outputs that resolve to the same entity label (e.g. Hospital1/2
        # both map to 住院A); later relations overwrite same-named earlier ones.
        existing_result = next((result for result in results if result['entity'] == output['entity']), None)
        if existing_result is not None:
            existing_result['relations'].update(output['relations'])
        else:
            results.append(output)

    # Drop event types for which nothing was extracted.
    results = [result for result in results if result['relations']]

    return json.dumps(results, indent=4, ensure_ascii=False)
|
227 |
+
|
228 |
+
title = "<p style='text-align: center'><a href='https://www.twman.org/AI/NLP' target='_blank'>醫囑分析:HugIE @ HugNLP</a>"
|
229 |
+
|
230 |
+
description = """
|
231 |
+
<p style='text-align: center'><a href="https://blog.twman.org/2023/07/HugIE.html" target='_blank'>基於機器閱讀理解(MRC)的指令微調(Instruction-tuning)的統一信息抽取框架之診斷書醫囑擷取分析</a></p><br>
|
232 |
+
<p style='text-align: center'><a href="https://github.com/Deep-Learning-101" target='_blank'>https://github.com/Deep-Learning-101</a></p><br>
|
233 |
+
<p style='text-align: center'><a href="https://github.com/Deep-Learning-101/Natural-Language-Processing-Paper" target='_blank'>https://github.com/Deep-Learning-101/Natural-Language-Processing-Paper</a></p><br>
|
234 |
+
"""
|
235 |
+
|
236 |
+
demo = gr.Interface(
|
237 |
+
fn=extract_information,
|
238 |
+
inputs=gr.components.Textbox(label="醫療診斷書之醫囑原始內容"),
|
239 |
+
outputs=gr.components.Textbox(label="醫療診斷書之醫囑擷取結果"),
|
240 |
+
examples = [
|
241 |
+
"患者因上述疾病,曾於112年02月13日12:15~112年02月13日13:43至本院急診治療,於112年02月13日轉灼傷中心普通病房,於112年02月17日接受傷口清創手術治療,於112年02月24日接受左上肢植皮重建手術治療,於112年03月03日轉出灼傷中心病房,於 112年03月09日病情穩定出院,曾於112年03月17日、112年03月21日、112年03月28日、112年04月07��、112年04月18日至本院門診治療,須穿著壓力衣避免疤痕增生,續門診追蹤",
|
242 |
+
"患者因甲狀腺乳突癌術後,依病歷記錄,患者接受王舒儀醫師於2023-03-29,郭仁富醫師於2023-05-02之本院門診追蹤治療,共計2次,並於2023-05-02至2023-05-03住院接受高劑量放射性碘隔離治療,現病況穩定予以出院,共計住院兩日,宜門診繼續追蹤治療。",
|
243 |
+
"1.患者因上述原因於202304-06在本院住院於2023-04-07施行開放性復位及鋼釘鋼板固定手術治療.術後應休養二個月患肢不宜提重物並使用手吊#六星明於2023-04-10計院續日診治蹤治療",
|
244 |
+
"病患曾於108-12-17至本院門診手術室接受右側經皮穿腎引留管換管手術治療,病患曾於108-12-17至本院門診治療",
|
245 |
+
"患者因上述原因曾於108年06月03日,12月06日,在本院門診接受子宮頸抹片追蹤檢查,建議返回長庚醫院後續癌症追蹤。",
|
246 |
+
"病人於民國108年09月14日從門診入院,住普通病房,於民國108年12月06日出院,特此證明。",
|
247 |
+
"該病患因上述疾病於民國108年5月18日至本院急診室就診,經傷口護理及診療後於當天出院,應於門診持續追蹤治療。",
|
248 |
+
"病人因上述症狀,於民國108年12月16日住院,接受自費欣普尼注射治療,並於民國108年12月17日出院,須門診追蹤治療。",
|
249 |
+
"該員於108年10月16日,因上述病情,入院施行治療,期間須使用呼吸器及氣墊床。於108年11月26日出院。",
|
250 |
+
"患肢不宜負重.宜休養3個月.宜使用三角巾固定.患者於民國108年01月23日至108年04月18日共至門診4次",
|
251 |
+
"病人因上述病症,於108年04月07日住入本院,接受支持性照護。108年04月10日出院於狀況穩定下予以出院。已安排後續放射線及化學治療。",
|
252 |
+
"病人因上述病情於108年05月25日入院至加護病房,於108年05月30日轉至普通病房,於108年06月03日出院。",
|
253 |
+
"病患曾於108年09月19日20:32~108年09月20日08:41至本院急診治療,於108年09月20日住院抗生素治療,108年09月26日出院.一週門診追蹤",
|
254 |
+
],
|
255 |
+
title=title,
|
256 |
+
description=description,
|
257 |
+
)
|
258 |
+
demo.launch(debug=True)
|
applications/information_extraction/HugIE/api_test.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
sys.path.append("./")
|
4 |
+
sys.path.append("../")
|
5 |
+
sys.path.append("../../")
|
6 |
+
sys.path.append("../../../")
|
7 |
+
from models import SPAN_EXTRACTION_MODEL_CLASSES
|
8 |
+
from models import TOKENIZER_CLASSES
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
|
12 |
+
|
13 |
+
class HugIEAPI:
    """Thin inference wrapper around a HugIE global-pointer span-extraction model.

    Builds an instruction-style query ("find entities of type X" or
    "find relation R of entity E"), tokenizes it, runs the model, and
    decodes the top-k span predictions back into substrings of the query.
    """

    def __init__(self, model_type, hugie_model_name_or_path) -> None:
        # model_type must be one of the registered global-pointer backbones
        # (e.g. "bert"); hugie_model_name_or_path is a HF hub id or local dir.
        if model_type not in SPAN_EXTRACTION_MODEL_CLASSES[
                "global_pointer"].keys():
            raise KeyError(
                "You must choose one of the following model: {}".format(
                    ", ".join(
                        list(SPAN_EXTRACTION_MODEL_CLASSES["global_pointer"].
                             keys()))))
        self.model_type = model_type
        self.model = SPAN_EXTRACTION_MODEL_CLASSES["global_pointer"][
            self.model_type].from_pretrained(hugie_model_name_or_path)
        self.tokenizer = TOKENIZER_CLASSES[self.model_type].from_pretrained(
            hugie_model_name_or_path)
        # Hard cap on tokenized instruction length (also the span-grid side).
        self.max_seq_length = 512

    def fush_multi_answer(self, has_answer, new_answer):
        """Merge a new example's top-k answers into the accumulated ones.

        When one test id produced several examples (e.g. multiple templates),
        their predictions are fused: probabilities of identical answer strings
        are summed and their positions concatenated.

        Both dicts have the shape
        {"ans": {"prob": float, "pos": [(start, end), ...]}, ...};
        *has_answer* is mutated in place and also returned.
        """
        for ans, value in new_answer.items():
            if ans not in has_answer.keys():
                has_answer[ans] = value
            else:
                has_answer[ans]["prob"] += value["prob"]
                has_answer[ans]["pos"].extend(value["pos"])
        return has_answer

    def get_predict_result(self, probs, indices, examples):
        """Decode top-k span scores into answer strings.

        probs/indices: model top-k probabilities and flattened span indices,
        one row per example.  examples carries the raw "content" strings and
        the tokenizer "offset_mapping" used to map token spans back to text.

        Returns (predictions, topk_predictions):
        predictions[idx]  -> list of unique answer strings;
        topk_predictions[idx] -> [{"answer", "prob", "pos"}, ...].
        """
        probs = probs.squeeze(1)    # top-k probabilities, shape [n, m]
        indices = indices.squeeze(1)  # top-k flattened span indices, shape [n, m]
        predictions = {}
        topk_predictions = {}
        idx = 0
        for prob, index in zip(probs, indices):
            index_ids = torch.Tensor([i for i in range(len(index))]).long()
            topk_answer = list()
            answer = []
            topk_answer_dict = dict()
            # TODO: 1. tune the probability threshold  2. handle overlapping spans
            entity_index = index[prob > 0.1]
            index_ids = index_ids[prob > 0.1]
            for ei, entity in enumerate(entity_index):
                # Convert the flattened 1-D span index into a 2-D (start, end)
                # position on the max_seq_length x max_seq_length span grid.
                start_end = np.unravel_index(
                    entity, (self.max_seq_length, self.max_seq_length))
                # Map token positions back to character offsets in the content.
                s = examples["offset_mapping"][idx][start_end[0]][0]
                e = examples["offset_mapping"][idx][start_end[1]][1]
                ans = examples["content"][idx][s:e]
                if ans not in answer:
                    answer.append(ans)
                    # NOTE(review): recorded only for the first occurrence of
                    # each answer string — confirm duplicates should be dropped.
                    topk_answer_dict[ans] = {
                        "prob":
                        float(prob[index_ids[ei]]),
                        "pos": [(s.detach().cpu().numpy().tolist(),
                                 e.detach().cpu().numpy().tolist())]
                    }

            predictions[idx] = answer
            if idx not in topk_predictions.keys():
                topk_predictions[idx] = topk_answer_dict
            else:
                # Same id seen again: fuse with previously accumulated answers.
                topk_predictions[idx] = self.fush_multi_answer(
                    topk_predictions[idx], topk_answer_dict)
            idx += 1

        for idx, values in topk_predictions.items():
            # Flatten {"ans": {...}} into a list of answer records.
            answer_list = list()
            for ans, value in values.items():
                answer_list.append({
                    "answer": ans,
                    "prob": value["prob"],
                    "pos": value["pos"]
                })
            topk_predictions[idx] = answer_list

        return predictions, topk_predictions

    def request(self, text: str, entity_type: str, relation: str = None):
        """Query the model once.

        Without *relation*: named-entity style ("find all entities of type X").
        With *relation*: relation/event style ("find relation R of entity E").
        Returns (predictions, topk_predictions) as produced by
        get_predict_result.
        """
        assert text is not None and entity_type is not None
        if relation is None:
            instruction = "找到文章中所有【{}】类型的实体?文章:【{}】".format(entity_type, text)
            # pre_len: length of the instruction prefix before the article text.
            # NOTE(review): computed but never used below — confirm intent.
            pre_len = 21 - 2 + len(entity_type)
        else:
            instruction = "找到文章中【{}】的【{}】?文章:【{}】".format(
                entity_type, relation, text)
            pre_len = 19 - 4 + len(entity_type) + len(relation)

        inputs = self.tokenizer(instruction,
                                max_length=self.max_seq_length,
                                padding="max_length",
                                return_tensors="pt",
                                return_offsets_mapping=True)

        examples = {
            "content": [instruction],
            "offset_mapping": inputs["offset_mapping"]
        }

        batch_input = {
            "input_ids": inputs["input_ids"],
            "token_type_ids": inputs["token_type_ids"],
            "attention_mask": inputs["attention_mask"],
        }

        outputs = self.model(**batch_input)

        probs, indices = outputs["topk_probs"], outputs["topk_indices"]
        predictions, topk_predictions = self.get_predict_result(
            probs, indices, examples=examples)

        return predictions, topk_predictions
|
133 |
+
|
134 |
+
|
135 |
+
if __name__ == "__main__":
    # Smoke-test script: load the published HugIE checkpoint and run one
    # NER query plus four relation/event-extraction queries on a sample
    # news sentence about the Tajikistan earthquake.
    from applications.information_extraction.HugIE.api_test import HugIEAPI
    model_type = "bert"
    hugie_model_name_or_path = "wjn1996/wjn1996-hugnlp-hugie-large-zh"
    hugie = HugIEAPI("bert", hugie_model_name_or_path)
    text = "央广网北京2月23日消息 据中国地震台网正式测定,2月23日8时37分在塔吉克斯坦发生7.2级地震,震源深度10公里,震中位于北纬37.98度,东经73.29度,距我国边境线最近约82公里,地震造成新疆喀什等地震感强烈。"

    ## named entity recognition
    entity_type = "国家"
    predictions, topk_predictions = hugie.request(text, entity_type)
    print("entity_type:{}".format(entity_type))
    print("predictions:\n{}".format(predictions))
    print("topk_predictions:\n{}".format(topk_predictions))
    print("\n\n")

    ## event extraction: epicentre depth
    entity = "塔吉克斯坦地震"
    relation = "震源深度"
    predictions, topk_predictions = hugie.request(text,
                                                  entity,
                                                  relation=relation)
    print("entity:{}, relation:{}".format(entity, relation))
    print("predictions:\n{}".format(predictions))
    print("topk_predictions:\n{}".format(topk_predictions))
    print("\n\n")

    ## event extraction: epicentre location
    entity = "塔吉克斯坦地震"
    relation = "震源位置"
    predictions, topk_predictions = hugie.request(text,
                                                  entity,
                                                  relation=relation)
    print("entity:{}, relation:{}".format(entity, relation))
    print("predictions:\n{}".format(predictions))
    print("topk_predictions:\n{}".format(topk_predictions))
    print("\n\n")

    ## event extraction: time of the event
    entity = "塔吉克斯坦地震"
    relation = "时间"
    predictions, topk_predictions = hugie.request(text,
                                                  entity,
                                                  relation=relation)
    print("entity:{}, relation:{}".format(entity, relation))
    print("predictions:\n{}".format(predictions))
    print("topk_predictions:\n{}".format(topk_predictions))
    print("\n\n")

    ## event extraction: impact of the event
    entity = "塔吉克斯坦地震"
    relation = "影响"
    predictions, topk_predictions = hugie.request(text,
                                                  entity,
                                                  relation=relation)
    print("entity:{}, relation:{}".format(entity, relation))
    print("predictions:\n{}".format(predictions))
    print("topk_predictions:\n{}".format(topk_predictions))
    print("\n\n")
# Reference transcript of the expected output of the script above.
"""
Output results:

entity_type:国家
predictions:
{0: ["塔吉克斯坦"]}
predictions:
{0: [{"answer": "塔吉克斯坦", "prob": 0.9999997615814209, "pos": [(tensor(57), tensor(62))]}]}



entity:塔吉克斯坦地震, relation:震源深度
predictions:
{0: ["10公里"]}
predictions:
{0: [{"answer": "10公里", "prob": 0.999994158744812, "pos": [(tensor(80), tensor(84))]}]}



entity:塔吉克斯坦地震, relation:震源位置
predictions:
{0: ["10公里", "距我国边境线最近约82公里", "北纬37.98度,东经73.29度", "北纬37.98度,东经73.29度,距我国边境线最近约82公里"]}
predictions:
{0: [{"answer": "10公里", "prob": 0.9895901083946228, "pos": [(tensor(80), tensor(84))]}, {"answer": "距我国边境线最近约82公里", "prob": 0.8584909439086914, "pos": [(tensor(107), tensor(120))]}, {"answer": "北纬37.98度,东经73.29度", "prob": 0.7202121615409851, "pos": [(tensor(89), tensor(106))]}, {"answer": "北纬37.98度,东经73.29度,距我国边境线最近约82公里", "prob": 0.11628123372793198, "pos": [(tensor(89), tensor(120))]}]}



entity:塔吉克斯坦地震, relation:时间
predictions:
{0: ["2月23日8时37分"]}
predictions:
{0: [{"answer": "2月23日8时37分", "prob": 0.9999995231628418, "pos": [(tensor(49), tensor(59))]}]}



entity:塔吉克斯坦地震, relation:影响
predictions:
{0: ["新疆喀什等地震感强烈"]}
predictions:
{0: [{"answer": "新疆喀什等地震感强烈", "prob": 0.9525265693664551, "pos": [(tensor(123), tensor(133))]}]}

"""
|
models/__init__.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2021/12/6 3:35 下午
|
3 |
+
# @Author : JianingWang
|
4 |
+
# @File : __init__.py
|
5 |
+
|
6 |
+
|
7 |
+
# from models.chid_mlm import BertForChidMLM
|
8 |
+
from models.multiple_choice.duma import BertDUMAForMultipleChoice, AlbertDUMAForMultipleChoice, MegatronDumaForMultipleChoice
|
9 |
+
from models.span_extraction.global_pointer import BertForEffiGlobalPointer, RobertaForEffiGlobalPointer, RoformerForEffiGlobalPointer, MegatronForEffiGlobalPointer
|
10 |
+
from transformers import AutoModelForTokenClassification, AutoModelForSequenceClassification, AutoModelForMaskedLM, AutoModelForMultipleChoice, BertTokenizer, \
|
11 |
+
AutoModelForQuestionAnswering, AutoModelForCausalLM
|
12 |
+
|
13 |
+
from transformers import AutoTokenizer
|
14 |
+
from transformers.models.roformer import RoFormerTokenizer
|
15 |
+
from transformers.models.bert import BertTokenizerFast, BertForTokenClassification, BertTokenizer
|
16 |
+
from transformers.models.roberta.tokenization_roberta import RobertaTokenizer
|
17 |
+
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
|
18 |
+
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
19 |
+
from transformers.models.bart.tokenization_bart import BartTokenizer
|
20 |
+
from transformers.models.t5.tokenization_t5 import T5Tokenizer
|
21 |
+
from transformers.models.plbart.tokenization_plbart import PLBartTokenizer
|
22 |
+
|
23 |
+
|
24 |
+
# from models.deberta import DebertaV2ForMultipleChoice, DebertaForMultipleChoice
|
25 |
+
# from models.fengshen.models.longformer import LongformerForMultipleChoice
|
26 |
+
from models.kg import BertForPretrainWithKG, BertForPretrainWithKGV2
|
27 |
+
from models.language_modeling.mlm import BertForMaskedLM, RobertaForMaskedLM, AlbertForMaskedLM, RoFormerForMaskedLM
|
28 |
+
# from models.sequence_classification.classification import build_cls_model
|
29 |
+
from models.multiple_choice.multiple_choice_tag import BertForTagMultipleChoice, RoFormerForTagMultipleChoice, MegatronBertForTagMultipleChoice
|
30 |
+
from models.multiple_choice.multiple_choice import MegatronBertForMultipleChoice, MegatronBertRDropForMultipleChoice
|
31 |
+
from models.semeval7 import DebertaV2ForSemEval7MultiTask
|
32 |
+
from models.sequence_matching.fusion_siamese import BertForFusionSiamese, BertForWSC
|
33 |
+
# from roformer import RoFormerForTokenClassification, RoFormerForSequenceClassification
|
34 |
+
from models.fewshot_learning.span_proto import SpanProto
|
35 |
+
from models.fewshot_learning.token_proto import TokenProto
|
36 |
+
|
37 |
+
from models.sequence_labeling.head_token_cls import (
|
38 |
+
BertSoftmaxForSequenceLabeling, BertCrfForSequenceLabeling,
|
39 |
+
RobertaSoftmaxForSequenceLabeling, RobertaCrfForSequenceLabeling,
|
40 |
+
AlbertSoftmaxForSequenceLabeling, AlbertCrfForSequenceLabeling,
|
41 |
+
MegatronBertSoftmaxForSequenceLabeling, MegatronBertCrfForSequenceLabeling,
|
42 |
+
)
|
43 |
+
from models.span_extraction.span_for_ner import BertSpanForNer, RobertaSpanForNer, AlbertSpanForNer, MegatronBertSpanForNer
|
44 |
+
|
45 |
+
from models.language_modeling.mlm import BertForMaskedLM
|
46 |
+
from models.language_modeling.kpplm import BertForWikiKGPLM, RoBertaKPPLMForProcessedWikiKGPLM, DeBertaKPPLMForProcessedWikiKGPLM
|
47 |
+
from models.language_modeling.causal_lm import GPT2ForCausalLM
|
48 |
+
|
49 |
+
from models.sequence_classification.head_cls import (
|
50 |
+
BertForSequenceClassification, BertPrefixForSequenceClassification,
|
51 |
+
BertPtuningForSequenceClassification, BertAdapterForSequenceClassification,
|
52 |
+
RobertaForSequenceClassification, RobertaPrefixForSequenceClassification,
|
53 |
+
RobertaPtuningForSequenceClassification,RobertaAdapterForSequenceClassification,
|
54 |
+
BartForSequenceClassification, GPT2ForSequenceClassification
|
55 |
+
)
|
56 |
+
|
57 |
+
from models.sequence_classification.masked_prompt_cls import (
|
58 |
+
PromptBertForSequenceClassification, PromptBertPtuningForSequenceClassification,
|
59 |
+
PromptBertPrefixForSequenceClassification, PromptBertAdapterForSequenceClassification,
|
60 |
+
PromptRobertaForSequenceClassification, PromptRobertaPtuningForSequenceClassification,
|
61 |
+
PromptRobertaPrefixForSequenceClassification, PromptRobertaAdapterForSequenceClassification
|
62 |
+
)
|
63 |
+
|
64 |
+
from models.sequence_classification.causal_prompt_cls import PromptGPT2ForSequenceClassification
|
65 |
+
|
66 |
+
from models.code.code_classification import (
|
67 |
+
RobertaForCodeClassification, CodeBERTForCodeClassification,
|
68 |
+
GraphCodeBERTForCodeClassification, PLBARTForCodeClassification, CodeT5ForCodeClassification
|
69 |
+
)
|
70 |
+
from models.code.code_generation import (
|
71 |
+
PLBARTForCodeGeneration
|
72 |
+
)
|
73 |
+
|
74 |
+
from models.reinforcement_learning.actor import CausalActor
|
75 |
+
from models.reinforcement_learning.critic import AutoModelCritic
|
76 |
+
from models.reinforcement_learning.reward_model import (
|
77 |
+
RobertaForReward, GPT2ForReward
|
78 |
+
)
|
79 |
+
|
80 |
+
# Models for pre-training
|
81 |
+
PRETRAIN_MODEL_CLASSES = {
|
82 |
+
"mlm": {
|
83 |
+
"bert": BertForMaskedLM,
|
84 |
+
"roberta": RobertaForMaskedLM,
|
85 |
+
"albert": AlbertForMaskedLM,
|
86 |
+
"roformer": RoFormerForMaskedLM,
|
87 |
+
},
|
88 |
+
"auto_mlm": AutoModelForMaskedLM,
|
89 |
+
"causal_lm": {
|
90 |
+
"gpt2": GPT2ForCausalLM,
|
91 |
+
"bart": None,
|
92 |
+
"t5": None,
|
93 |
+
"llama": None
|
94 |
+
},
|
95 |
+
"auto_causal_lm": AutoModelForCausalLM
|
96 |
+
}
|
97 |
+
|
98 |
+
|
99 |
+
CLASSIFICATION_MODEL_CLASSES = {
|
100 |
+
"auto_cls": AutoModelForSequenceClassification, # huggingface cls
|
101 |
+
"classification": AutoModelForSequenceClassification, # huggingface cls
|
102 |
+
"head_cls": {
|
103 |
+
"bert": BertForSequenceClassification,
|
104 |
+
"roberta": RobertaForSequenceClassification,
|
105 |
+
"bart": BartForSequenceClassification,
|
106 |
+
"gpt2": GPT2ForSequenceClassification
|
107 |
+
}, # use standard fine-tuning head for cls, e.g., bert+mlp
|
108 |
+
"head_prefix_cls": {
|
109 |
+
"bert": BertPrefixForSequenceClassification,
|
110 |
+
"roberta": RobertaPrefixForSequenceClassification,
|
111 |
+
}, # use standard fine-tuning head with prefix-tuning technique for cls, e.g., bert+mlp
|
112 |
+
"head_ptuning_cls": {
|
113 |
+
"bert": BertPtuningForSequenceClassification,
|
114 |
+
"roberta": RobertaPtuningForSequenceClassification,
|
115 |
+
}, # use standard fine-tuning head with p-tuning technique for cls, e.g., bert+mlp
|
116 |
+
"head_adapter_cls": {
|
117 |
+
"bert": BertAdapterForSequenceClassification,
|
118 |
+
"roberta": RobertaAdapterForSequenceClassification,
|
119 |
+
}, # use standard fine-tuning head with adapter-tuning technique for cls, e.g., bert+mlp
|
120 |
+
"masked_prompt_cls": {
|
121 |
+
"bert": PromptBertForSequenceClassification,
|
122 |
+
"roberta": PromptRobertaForSequenceClassification,
|
123 |
+
# "deberta": PromptDebertaForSequenceClassification,
|
124 |
+
# "deberta-v2": PromptDebertav2ForSequenceClassification,
|
125 |
+
}, # use masked lm head technique for prompt-based cls, e.g., bert+mlm
|
126 |
+
"masked_prompt_prefix_cls": {
|
127 |
+
"bert": PromptBertPrefixForSequenceClassification,
|
128 |
+
"roberta": PromptRobertaPrefixForSequenceClassification,
|
129 |
+
# "deberta": PromptDebertaPrefixForSequenceClassification,
|
130 |
+
# "deberta-v2": PromptDebertav2PrefixForSequenceClassification,
|
131 |
+
}, # use masked lm head with prefix-tuning technique for prompt-based cls, e.g., bert+mlm
|
132 |
+
"masked_prompt_ptuning_cls": {
|
133 |
+
"bert": PromptBertPtuningForSequenceClassification,
|
134 |
+
"roberta": PromptRobertaPtuningForSequenceClassification,
|
135 |
+
# "deberta": PromptDebertaPtuningForSequenceClassification,
|
136 |
+
# "deberta-v2": PromptDebertav2PtuningForSequenceClassification,
|
137 |
+
}, # use masked lm head with p-tuning technique for prompt-based cls, e.g., bert+mlm
|
138 |
+
"masked_prompt_adapter_cls": {
|
139 |
+
"bert": PromptBertAdapterForSequenceClassification,
|
140 |
+
"roberta": PromptRobertaAdapterForSequenceClassification,
|
141 |
+
}, # use masked lm head with adapter-tuning technique for prompt-based cls, e.g., bert+mlm
|
142 |
+
"causal_prompt_cls": {
|
143 |
+
"gpt2": PromptGPT2ForSequenceClassification,
|
144 |
+
"bart": None,
|
145 |
+
"t5": None,
|
146 |
+
}, # use causal lm head for prompt-tuning, e.g., gpt2+lm
|
147 |
+
}
|
148 |
+
|
149 |
+
|
150 |
+
TOKEN_CLASSIFICATION_MODEL_CLASSES = {
|
151 |
+
"auto_token_cls": AutoModelForTokenClassification,
|
152 |
+
"head_softmax_token_cls": {
|
153 |
+
"bert": BertSoftmaxForSequenceLabeling,
|
154 |
+
"roberta": RobertaSoftmaxForSequenceLabeling,
|
155 |
+
"albert": AlbertSoftmaxForSequenceLabeling,
|
156 |
+
"megatron": MegatronBertSoftmaxForSequenceLabeling,
|
157 |
+
},
|
158 |
+
"head_crf_token_cls": {
|
159 |
+
"bert": BertCrfForSequenceLabeling,
|
160 |
+
"roberta": RobertaCrfForSequenceLabeling,
|
161 |
+
"albert": AlbertCrfForSequenceLabeling,
|
162 |
+
"megatron": MegatronBertCrfForSequenceLabeling,
|
163 |
+
}
|
164 |
+
}
|
165 |
+
|
166 |
+
|
167 |
+
SPAN_EXTRACTION_MODEL_CLASSES = {
|
168 |
+
"global_pointer": {
|
169 |
+
"bert": BertForEffiGlobalPointer,
|
170 |
+
"roberta": RobertaForEffiGlobalPointer,
|
171 |
+
"roformer": RoformerForEffiGlobalPointer,
|
172 |
+
"megatronbert": MegatronForEffiGlobalPointer
|
173 |
+
},
|
174 |
+
}
|
175 |
+
|
176 |
+
|
177 |
+
FEWSHOT_MODEL_CLASSES = {
|
178 |
+
"sequence_proto": None,
|
179 |
+
"span_proto": SpanProto,
|
180 |
+
"token_proto": TokenProto,
|
181 |
+
}
|
182 |
+
|
183 |
+
|
184 |
+
CODE_MODEL_CLASSES = {
|
185 |
+
"code_cls": {
|
186 |
+
"roberta": RobertaForCodeClassification,
|
187 |
+
"codebert": CodeBERTForCodeClassification,
|
188 |
+
"graphcodebert": GraphCodeBERTForCodeClassification,
|
189 |
+
"codet5": CodeT5ForCodeClassification,
|
190 |
+
"plbart": PLBARTForCodeClassification,
|
191 |
+
},
|
192 |
+
"code_generation": {
|
193 |
+
# "roberta": RobertaForCodeGeneration,
|
194 |
+
# "codebert": BertForCodeGeneration,
|
195 |
+
# "graphcodebert": BertForCodeGeneration,
|
196 |
+
# "codet5": T5ForCodeGeneration,
|
197 |
+
"plbart": PLBARTForCodeGeneration,
|
198 |
+
},
|
199 |
+
}
|
200 |
+
|
201 |
+
REINFORCEMENT_MODEL_CLASSES = {
|
202 |
+
"causal_actor": CausalActor,
|
203 |
+
"auto_critic": AutoModelCritic,
|
204 |
+
"rl_reward": {
|
205 |
+
"roberta": RobertaForReward,
|
206 |
+
"gpt2": GPT2ForReward,
|
207 |
+
"gpt-neo": None,
|
208 |
+
"opt": None,
|
209 |
+
"llama": None,
|
210 |
+
}
|
211 |
+
}
|
212 |
+
|
213 |
+
# task_type 负责对应model类型
|
214 |
+
OTHER_MODEL_CLASSES = {
|
215 |
+
# sequence labeling
|
216 |
+
"bert_span_ner": BertSpanForNer,
|
217 |
+
"roberta_span_ner": RobertaSpanForNer,
|
218 |
+
"albert_span_ner": AlbertSpanForNer,
|
219 |
+
"megatronbert_span_ner": MegatronBertSpanForNer,
|
220 |
+
# sequence matching
|
221 |
+
"fusion_siamese": BertForFusionSiamese,
|
222 |
+
# multiple choice
|
223 |
+
"multi_choice": AutoModelForMultipleChoice,
|
224 |
+
"multi_choice_megatron": MegatronBertForMultipleChoice,
|
225 |
+
"multi_choice_megatron_rdrop": MegatronBertRDropForMultipleChoice,
|
226 |
+
"megatron_multi_choice_tag": MegatronBertForTagMultipleChoice,
|
227 |
+
"roformer_multi_choice_tag": RoFormerForTagMultipleChoice,
|
228 |
+
"multi_choice_tag": BertForTagMultipleChoice,
|
229 |
+
"duma": BertDUMAForMultipleChoice,
|
230 |
+
"duma_albert": AlbertDUMAForMultipleChoice,
|
231 |
+
"duma_megatron": MegatronDumaForMultipleChoice,
|
232 |
+
# language modeling
|
233 |
+
|
234 |
+
# "bert_mlm_acc": BertForMaskedLMWithACC,
|
235 |
+
# "roformer_mlm_acc": RoFormerForMaskedLMWithACC,
|
236 |
+
"bert_pretrain_kg": BertForPretrainWithKG,
|
237 |
+
"bert_pretrain_kg_v2": BertForPretrainWithKGV2,
|
238 |
+
"kpplm_roberta": RoBertaKPPLMForProcessedWikiKGPLM,
|
239 |
+
"kpplm_deberta": DeBertaKPPLMForProcessedWikiKGPLM,
|
240 |
+
|
241 |
+
# other
|
242 |
+
"clue_wsc": BertForWSC,
|
243 |
+
"semeval7multitask": DebertaV2ForSemEval7MultiTask,
|
244 |
+
# "debertav2_multi_choice": DebertaV2ForMultipleChoice,
|
245 |
+
# "deberta_multi_choice": DebertaForMultipleChoice,
|
246 |
+
# "qa": AutoModelForQuestionAnswering,
|
247 |
+
# "roformer_cls": RoFormerForSequenceClassification,
|
248 |
+
# "roformer_ner": RoFormerForTokenClassification,
|
249 |
+
# "fensheng_multi_choice": LongformerForMultipleChoice,
|
250 |
+
# "chid_mlm": BertForChidMLM,
|
251 |
+
}
|
252 |
+
|
253 |
+
|
254 |
+
# MODEL_CLASSES = dict(list(PRETRAIN_MODEL_CLASSES.items()) + list(OTHER_MODEL_CLASSES.items()))
|
255 |
+
MODEL_CLASSES_LIST = [
|
256 |
+
PRETRAIN_MODEL_CLASSES,
|
257 |
+
CLASSIFICATION_MODEL_CLASSES,
|
258 |
+
TOKEN_CLASSIFICATION_MODEL_CLASSES,
|
259 |
+
SPAN_EXTRACTION_MODEL_CLASSES,
|
260 |
+
FEWSHOT_MODEL_CLASSES,
|
261 |
+
CODE_MODEL_CLASSES,
|
262 |
+
REINFORCEMENT_MODEL_CLASSES,
|
263 |
+
OTHER_MODEL_CLASSES,
|
264 |
+
]
|
265 |
+
|
266 |
+
|
267 |
+
MODEL_CLASSES = dict()
|
268 |
+
for model_class in MODEL_CLASSES_LIST:
|
269 |
+
MODEL_CLASSES = dict(list(MODEL_CLASSES.items()) + list(model_class.items()))
|
270 |
+
|
271 |
+
# model_type 负责对应tokenizer
|
272 |
+
TOKENIZER_CLASSES = {
|
273 |
+
# for natural language processing
|
274 |
+
"auto": AutoTokenizer,
|
275 |
+
"bert": BertTokenizerFast,
|
276 |
+
"roberta": RobertaTokenizer,
|
277 |
+
"wobert": RoFormerTokenizer,
|
278 |
+
"roformer": RoFormerTokenizer,
|
279 |
+
"bigbird": BertTokenizerFast,
|
280 |
+
"erlangshen": BertTokenizerFast,
|
281 |
+
"deberta": BertTokenizer,
|
282 |
+
"roformer_v2": BertTokenizerFast,
|
283 |
+
"gpt2": GPT2Tokenizer,
|
284 |
+
"megatronbert": BertTokenizerFast,
|
285 |
+
"bart": BartTokenizer,
|
286 |
+
"t5": T5Tokenizer,
|
287 |
+
# for programming language processing
|
288 |
+
"codebert": RobertaTokenizer,
|
289 |
+
"graphcodebert": RobertaTokenizer,
|
290 |
+
"codet5": RobertaTokenizer,
|
291 |
+
"plbart": PLBartTokenizer
|
292 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
jieba
|
2 |
+
roformer
|
3 |
+
scikit-learn
|
4 |
+
sentence-transformers
|
5 |
+
sentencepiece
|
6 |
+
torch==1.12.1
|
7 |
+
transformers==4.21.2
|
8 |
+
tqdm
|
9 |
+
ujson
|
10 |
+
gradio==2.3.0
|
11 |
+
gradio_client==0.2.7
|
wjn1996-hugnlp-hugie-large-zh/config.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"RoPE": true,
|
3 |
+
"_name_or_path": "/wjn/projects/information_extraction/HugIE/outputs/zh_instruction/chinese-macbert-large/chinese-macbert-large",
|
4 |
+
"architectures": [
|
5 |
+
"BertForEffiGlobalPointer"
|
6 |
+
],
|
7 |
+
"attention_probs_dropout_prob": 0.1,
|
8 |
+
"bos_token_id": 0,
|
9 |
+
"classifier_dropout": null,
|
10 |
+
"directionality": "bidi",
|
11 |
+
"ent_type_size": 1,
|
12 |
+
"eos_token_id": 2,
|
13 |
+
"finetuning_task": "laic",
|
14 |
+
"hidden_act": "gelu",
|
15 |
+
"hidden_dropout_prob": 0.1,
|
16 |
+
"hidden_size": 1024,
|
17 |
+
"initializer_range": 0.02,
|
18 |
+
"inner_dim": 64,
|
19 |
+
"intermediate_size": 4096,
|
20 |
+
"layer_norm_eps": 1e-12,
|
21 |
+
"max_position_embeddings": 512,
|
22 |
+
"model_type": "bert",
|
23 |
+
"num_attention_heads": 16,
|
24 |
+
"num_hidden_layers": 24,
|
25 |
+
"output_past": true,
|
26 |
+
"pad_token_id": 0,
|
27 |
+
"pooler_fc_size": 768,
|
28 |
+
"pooler_num_attention_heads": 12,
|
29 |
+
"pooler_num_fc_layers": 3,
|
30 |
+
"pooler_size_per_head": 128,
|
31 |
+
"pooler_type": "first_token_transform",
|
32 |
+
"position_embedding_type": "absolute",
|
33 |
+
"torch_dtype": "float32",
|
34 |
+
"transformers_version": "4.21.2",
|
35 |
+
"type_vocab_size": 2,
|
36 |
+
"use_cache": true,
|
37 |
+
"vocab_size": 21128
|
38 |
+
}
|
wjn1996-hugnlp-hugie-large-zh/gitattributes.txt
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
wjn1996-hugnlp-hugie-large-zh/special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cls_token": "[CLS]",
|
3 |
+
"mask_token": "[MASK]",
|
4 |
+
"pad_token": "[PAD]",
|
5 |
+
"sep_token": "[SEP]",
|
6 |
+
"unk_token": "[UNK]"
|
7 |
+
}
|
wjn1996-hugnlp-hugie-large-zh/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
wjn1996-hugnlp-hugie-large-zh/tokenizer_config.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cls_token": "[CLS]",
|
3 |
+
"do_lower_case": true,
|
4 |
+
"mask_token": "[MASK]",
|
5 |
+
"name_or_path": "/wjn/projects/information_extraction/HugIE/outputs/zh_instruction/chinese-macbert-large/chinese-macbert-large",
|
6 |
+
"pad_token": "[PAD]",
|
7 |
+
"sep_token": "[SEP]",
|
8 |
+
"special_tokens_map_file": "/wjn/pre-trained-lm/chinese-macbert-large/special_tokens_map.json",
|
9 |
+
"strip_accents": null,
|
10 |
+
"tokenize_chinese_chars": true,
|
11 |
+
"tokenizer_class": "BertTokenizer",
|
12 |
+
"unk_token": "[UNK]",
|
13 |
+
"use_fast": true
|
14 |
+
}
|
wjn1996-hugnlp-hugie-large-zh/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|