cheesexuebao commited on
Commit
f845b05
1 Parent(s): b6be546

standard repo for pre.

Browse files
app.py CHANGED
@@ -3,19 +3,7 @@ import pandas as pd
3
  from Prediction import *
4
  import os
5
  from datetime import datetime
6
- import re
7
- import json
8
- import hashlib
9
 
10
- persistent_path = "/output"
11
- # os.environ['HF_HOME'] = os.path.join(persistent_path, ".huggingface")
12
- user_input_path = os.path.join(persistent_path, 'user.jsonl')
13
- secret = "2fc9ff032e027e8f23bb9fb693234899"
14
-
15
- def get_md5(s):
16
- md = hashlib.md5()
17
- md.update(s.encode('utf-8'))
18
- return md.hexdigest()
19
 
20
  examples = []
21
  if os.path.exists("assets/examples.txt"):
@@ -53,72 +41,6 @@ def csv_process(csv_file, attr="content"):
53
  outputs.append(output_path)
54
  return outputs
55
 
56
- def logfile_query(auth):
57
- if get_md5(auth) == secret and os.path.exists(user_input_path):
58
- return [user_input_path]
59
- else:
60
- return None
61
-
62
-
63
- def check_save(fname, lname, cnum, email, oname, position):
64
- errors = []
65
- valid_vars = {}
66
-
67
- if not fname.strip() or not lname.strip():
68
- errors.append("Name cannot be empty")
69
- elif fname.isdigit() or lname.isdigit():
70
- errors.append("Name cannot be purely numerical")
71
- else:
72
- valid_vars["fname"] = fname
73
- valid_vars["lname"] = lname
74
-
75
- valid_vars["cnum"] = ''
76
- if cnum:
77
- if not cnum.isdigit():
78
- errors.append("The phone number must be a pure number")
79
- else:
80
- valid_vars["cnum"] = cnum
81
-
82
- if not email.strip():
83
- errors.append("Email cannot be empty")
84
- elif not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', email):
85
- errors.append("Incorrect email format")
86
- else:
87
- valid_vars["email"] = email
88
-
89
- if not oname.strip():
90
- errors.append("Organization name cannot be empty")
91
- elif oname.isdigit():
92
- errors.append("Organization cannot be purely numerical")
93
- else:
94
- valid_vars["oname"] = oname
95
-
96
- valid_vars["position"] = ''
97
- if position:
98
- if position.isdigit():
99
- errors.append("Position in your company cannot be purely numerical")
100
- else:
101
- valid_vars["position"] = position
102
-
103
- if errors:
104
- return errors
105
-
106
- current_time = datetime.now()
107
- formatted_time = current_time.strftime("%Y_%m_%d_%H_%M_%S")
108
- valid_vars['time'] = formatted_time
109
-
110
- with open(user_input_path, 'a+', encoding="utf8") as file:
111
- file.write(json.dumps(valid_vars)+"\n")
112
-
113
- records = {}
114
- with open(user_input_path, 'r', encoding="utf8") as file:
115
- for line in file:
116
- line = line.strip()
117
- dct = json.loads(line)
118
- records[dct['time']] = dct
119
-
120
- return records
121
-
122
 
123
  my_theme = gr.Theme.from_hub("JohnSmith9982/small_and_pretty")
124
  with gr.Blocks(theme=my_theme, title='Brand_Tone_of_Voice_demo') as demo:
@@ -138,116 +60,66 @@ with gr.Blocks(theme=my_theme, title='Brand_Tone_of_Voice_demo') as demo:
138
  </div>
139
  </div>
140
  """)
141
-
142
- with gr.Column(visible=True) as regis:
143
- gr.Markdown("# Welcome to BTV! Please fill out the form below to continue.\nI’m assuming that you mention somewhere that this project/research is conducted by the University of Manchester/AMBS. By ticking this box, I consent to be approached by the research team of the University of Manchester.")
144
- with gr.Column(variant='panel'):
145
- fname_tb = gr.Textbox(label="First Name: ", type='text')
146
- lname_tb = gr.Textbox(label="Last Name: ", type='text')
147
- email_tb = gr.Textbox(label="Email: ", type='email')
148
- cnum_tb = gr.Textbox(label="Contact: (Optional)", type='text')
149
- oname_tb = gr.Textbox(label="Organization name: ", type='text')
150
- position_tb = gr.Textbox(label="Positions in your company: (Optional)", type='text')
151
- error_box = gr.HTML(value="", visible=False)
152
- submit_btn = gr.Button("Click here to start if you have fullfill all the item!")
153
-
154
- with gr.Row(visible=False) as mainrow:
155
-
156
- with gr.Tab("Single Sentence"):
157
- with gr.Row():
158
- tbox_input = gr.Textbox(label="Input",
159
- info="Please input a sentence here:")
160
- gr.Markdown("""
161
- # Detailed information about our model:
162
- ...
163
- """)
164
- tab_output = gr.DataFrame(label='Predictions:',
165
- headers=["Label", "Probability"],
166
- datatype=["str", "number"],
167
- interactive=False)
168
- with gr.Row():
169
- button_ss = gr.Button("Submit", variant="primary")
170
- button_ss.click(fn=single_sentence, inputs=[tbox_input], outputs=[tab_output])
171
- gr.ClearButton([tbox_input, tab_output])
172
-
173
- gr.Examples(
174
- examples=examples,
175
- inputs=tbox_input,
176
- examples_per_page=len(examples)
177
- )
178
-
179
- with gr.Tab("Csv File"):
180
- with gr.Row():
181
- csv_input = gr.File(label="CSV File:",
182
- file_types=['.csv'],
183
- file_count="single"
184
- )
185
- csv_output = gr.File(label="Predictions:")
186
-
187
- with gr.Row():
188
- button_cf = gr.Button("Submit", variant="primary")
189
- button_cf.click(fn=csv_process, inputs=[csv_input], outputs=[csv_output])
190
- gr.ClearButton([csv_input, csv_output])
191
-
192
- gr.Markdown("## Examples \n The incoming CSV must include the ``content`` field, which represents the text that needs to be predicted!")
193
- gr.DataFrame(label='Csv input format:',
194
- value=[[i, examples[i]] for i in range(len(examples))],
195
- headers=["index", "content"],
196
- datatype=["number","str"],
197
- interactive=False
198
- )
199
-
200
- with gr.Tab("Readme"):
201
- gr.Markdown(
202
- """
203
- # Paper Name
204
-
205
- # Authors
206
-
207
- + First author
208
- + Corresponding author
209
-
210
- # Detailed Information
211
-
212
- ...
213
- """
214
- )
215
-
216
- with gr.Tab("Log File"):
217
- with gr.Row():
218
- auth_token = gr.Textbox(label="Authentication Tokens: ", info="Enter the key to download persistent stored log information.")
219
- log_output = gr.File(label="Log file: ")
220
-
221
- with gr.Row():
222
- button_lf = gr.Button("Validate", variant="primary")
223
- button_lf.click(fn=logfile_query, inputs=[auth_token], outputs=[log_output])
224
- gr.ClearButton([auth_token, log_output])
225
-
226
-
227
- def submit(*user_input):
228
- res = check_save(*user_input)
229
- if isinstance(res, list):
230
- return {
231
- error_box: gr.HTML(
232
- value=f"""
233
- <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
234
- <div>
235
- <p style="color:red;">{"; ".join(res)}</p>
236
- </div>
237
- </div>
238
- """,
239
- visible=True)
240
- }
241
- else:
242
- return {
243
- mainrow: gr.Row(visible=True),
244
- regis: gr.Row(visible=False),
245
- error_box: gr.HTML(visible=False)
246
- }
247
-
248
- submit_btn.click(
249
- submit,
250
- [fname_tb, lname_tb, cnum_tb, email_tb, oname_tb, position_tb],
251
- [mainrow, regis, error_box],
252
- )
253
  demo.launch()
 
3
  from Prediction import *
4
  import os
5
  from datetime import datetime
 
 
 
6
 
 
 
 
 
 
 
 
 
 
7
 
8
  examples = []
9
  if os.path.exists("assets/examples.txt"):
 
41
  outputs.append(output_path)
42
  return outputs
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  my_theme = gr.Theme.from_hub("JohnSmith9982/small_and_pretty")
46
  with gr.Blocks(theme=my_theme, title='Brand_Tone_of_Voice_demo') as demo:
 
60
  </div>
61
  </div>
62
  """)
63
+
64
+ with gr.Tab("Readme"):
65
+ gr.Markdown("""
66
+ # Detailed information about our model:
67
+
68
+ The example model here is a tone classification model suitable for financial field texts.
69
+
70
+ # Paper Name
71
+
72
+ # Authors
73
+
74
+ + First author
75
+ + Corresponding author
76
+
77
+ # How to use?
78
+
79
+ Please refer to the other two tab card for predictions.
80
+
81
+ + The `Single Sentence` for the tone classification of individual sentence.
82
+ + The `CSV File` for inputting CSV file for batch prediction and return.
83
+ ...
84
+ """)
85
+
86
+ with gr.Tab("Single Sentence"):
87
+ tbox_input = gr.Textbox(label="Input",
88
+ info="Please input a sentence here:")
89
+
90
+ tab_output = gr.DataFrame(label='Predictions:',
91
+ headers=["Category", "Probability"],
92
+ datatype=["str", "number"],
93
+ interactive=False)
94
+ with gr.Row():
95
+ button_ss = gr.Button("Submit", variant="primary")
96
+ button_ss.click(fn=single_sentence, inputs=[tbox_input], outputs=[tab_output])
97
+ gr.ClearButton([tbox_input, tab_output])
98
+
99
+ gr.Examples(
100
+ examples=examples,
101
+ inputs=tbox_input,
102
+ examples_per_page=len(examples)
103
+ )
104
+
105
+ with gr.Tab("Csv File"):
106
+ with gr.Row():
107
+ csv_input = gr.File(label="CSV File:",
108
+ file_types=['.csv'],
109
+ file_count="single"
110
+ )
111
+ csv_output = gr.File(label="Predictions:")
112
+
113
+ with gr.Row():
114
+ button = gr.Button("Submit", variant="primary")
115
+ button.click(fn=csv_process, inputs=[csv_input], outputs=[csv_output])
116
+ gr.ClearButton([csv_input, csv_output])
117
+
118
+ gr.Markdown("## Examples \n The incoming CSV must include the ``content`` field, which represents the text that needs to be predicted!")
119
+ gr.DataFrame(label='Csv input format:',
120
+ value=[[i, examples[i]] for i in range(len(examples))],
121
+ headers=["index", "content"],
122
+ datatype=["number","str"],
123
+ interactive=False
124
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  demo.launch()
assets/Kickstarter_sentence_level_5000.csv DELETED
The diff for this file is too large to render. See raw diff
 
assets/Prediction.py.bak DELETED
@@ -1,132 +0,0 @@
1
- ### install the needed package
2
- # !pip install transformers
3
- # !pip install torchmetrics
4
- # !pip3 install ogb pytorch_lightning -q
5
-
6
-
7
-
8
- import pandas as pd
9
- from tqdm.auto import tqdm
10
- import torch
11
- import torch.nn as nn
12
- from torch.utils.data import DataLoader, Dataset
13
- from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
14
- # import pytorch_lightning as pl
15
-
16
- pd.set_option('display.max_columns', 500)
17
-
18
- RANDOM_SEED = 42
19
-
20
-
21
- class ModelTagger(nn.Module):
22
- def __init__(self, model_path="bert-base-uncased"):
23
- super().__init__()
24
-
25
- self.bert = BertModel.from_pretrained(model_path, return_dict=True)
26
- self.classifier = nn.Linear(self.bert.config.hidden_size, 4)
27
- self.criterion = nn.BCELoss()
28
-
29
-
30
- def forward(self, input_ids, attention_mask, labels=None):
31
-
32
- output = self.bert(input_ids, attention_mask=attention_mask)
33
- output = self.classifier(output.pooler_output)
34
- output = torch.sigmoid(output)
35
- loss = 0
36
-
37
- if labels is not None:
38
- loss = self.criterion(output, labels)
39
- return loss, output
40
-
41
-
42
- class Predict_Dataset(Dataset):
43
- def __init__(
44
- self,
45
- data: pd.DataFrame,
46
- text_col: str,
47
- tokenizer: BertTokenizer,
48
- max_token_len: int = 128
49
- ):
50
- self.text_col = text_col
51
- self.tokenizer = tokenizer
52
- self.data = data
53
- self.max_token_len = max_token_len
54
-
55
- def __len__(self):
56
- return len(self.data)
57
-
58
-
59
- def __getitem__(self, index: int):
60
- data_row = self.data.iloc[index]
61
- post = data_row[self.text_col]
62
- encoding = self.tokenizer.encode_plus(
63
- post,
64
- add_special_tokens=True,
65
- max_length=self.max_token_len,
66
- return_token_type_ids=False,
67
- padding="max_length",
68
- truncation=True,
69
- return_attention_mask=True,
70
- return_tensors='pt',
71
- )
72
- return dict(
73
- post=post,
74
- input_ids=encoding["input_ids"].flatten(),
75
- attention_mask=encoding["attention_mask"].flatten(),
76
- )
77
-
78
-
79
- def predict(data, text_col, tokenizer, model, device, LABEL_COLUMNS, max_token_len=128):
80
- predictions = []
81
-
82
- df_token = Predict_Dataset(data, text_col, tokenizer, max_token_len=max_token_len)
83
- loader = DataLoader(df_token, batch_size=1000, num_workers=0)
84
-
85
- for item in tqdm(loader):
86
- _, prediction = model(
87
- item["input_ids"].to(device),
88
- item["attention_mask"].to(device)
89
- )
90
- predictions.append(prediction.detach().cpu())
91
-
92
- final_pred = torch.cat(predictions, dim=0)
93
- y_inten = final_pred.numpy().T
94
-
95
- return {
96
- LABEL_COLUMNS[0]: y_inten[0].tolist(),
97
- LABEL_COLUMNS[1]: y_inten[1].tolist(),
98
- LABEL_COLUMNS[2]: y_inten[2].tolist(),
99
- LABEL_COLUMNS[3]: y_inten[3].tolist()
100
- }
101
-
102
-
103
- def get_result(df, result, LABEL_COLUMNS):
104
- df[LABEL_COLUMNS[0]] = result[LABEL_COLUMNS[0]]
105
- df[LABEL_COLUMNS[1]] = result[LABEL_COLUMNS[1]]
106
- df[LABEL_COLUMNS[2]] = result[LABEL_COLUMNS[2]]
107
- df[LABEL_COLUMNS[3]] = result[LABEL_COLUMNS[3]]
108
- return df
109
-
110
-
111
- Data = pd.read_csv("Kickstarter_sentence_level_5000.csv")
112
- Data = Data[:20]
113
- device = torch.device('cpu')
114
- BERT_MODEL_NAME = 'bert-base-uncased'
115
- tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
116
- LABEL_COLUMNS = ["Assertive Tone", "Conversational Tone", "Emotional Tone", "Informative Tone"]
117
-
118
- params = torch.load("checkpoints/Kickstarter.ckpt", map_location='cpu')['state_dict']
119
- kick_model = ModelTagger()
120
- kick_model.load_state_dict(params, strict=True)
121
- kick_model.eval()
122
-
123
- kick_model = kick_model.to(device)
124
-
125
- kick_fk_doc_result = predict(Data,"content", tokenizer,kick_model, device, LABEL_COLUMNS)
126
-
127
- fk_result = get_result(Data, kick_fk_doc_result, LABEL_COLUMNS)
128
-
129
- fk_result.to_csv("output/prediction_origin_Kickstarter.csv")
130
-
131
-
132
- # tab_output = gr.Label(label='Probability Predictions:', value=dict(zip(LABEL_COLUMNS, [0]*len(LABEL_COLUMNS))))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
convert.py DELETED
@@ -1,30 +0,0 @@
1
- import torch
2
- import glob
3
- import os
4
- from transformers import BertTokenizerFast as BertTokenizer, BertForSequenceClassification
5
-
6
- os.environ['https_proxy'] = "127.0.0.1:1081"
7
-
8
- LABEL_COLUMNS = ["Assertive Tone", "Conversational Tone", "Emotional Tone", "Informative Tone", "None"]
9
-
10
- tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
11
- model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=5)
12
- id2label = {i:label for i,label in enumerate(LABEL_COLUMNS)}
13
- label2id = {label:i for i,label in enumerate(LABEL_COLUMNS)}
14
-
15
- for ckpt in glob.glob('checkpoints/*.ckpt'):
16
- base_name = os.path.basename(ckpt)
17
- # 去除文件后缀
18
- model_name = os.path.splitext(base_name)[0]
19
- params = torch.load(ckpt, map_location="cpu")['state_dict']
20
- msg = model.load_state_dict(params, strict=True)
21
- path = f'models/{model_name}'
22
- os.makedirs(path, exist_ok=True)
23
-
24
- torch.save(model.state_dict(), f'{path}/pytorch_model.bin')
25
- config = model.config
26
- config.architectures = ['BertForSequenceClassification']
27
- config.label2id = label2id
28
- config.id2label = id2label
29
- model.config.to_json_file(f'{path}/config.json')
30
- tokenizer.save_vocabulary(path)