Peter committed on
Commit
74b8229
1 Parent(s): 2e158ce

:tada: init from template

Browse files
.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # basics
2
+ *__pycache__*
3
+
4
+ # local testing
5
+ *aitextgen*
6
+ *scratch*
7
+ *tmp*
8
+
9
+ # gradio database files
10
+ *gradio_db_files*
11
+ *gradio*
12
+ *flagged*
app.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ app.py - the main file for the app. This builds the gradio interface for the chatbot and launches it.
3
+
4
+ """
5
+
6
+ import torch
7
+ from transformers import pipeline
8
+ from cleantext import clean
9
+ from pathlib import Path
10
+ import warnings
11
+ import time
12
+ import argparse
13
+ import logging
14
+ import gradio as gr
15
+ import os
16
+ import sys
17
+ from os.path import dirname
18
+ import nltk
19
+ from converse import discussion
20
+ from grammar_improve import (
21
+ detect_propers,
22
+ load_ns_checker,
23
+ neuspell_correct,
24
+ remove_repeated_words,
25
+ remove_trailing_punctuation,
26
+ build_symspell_obj,
27
+ symspeller,
28
+ fix_punct_spacing,
29
+ )
30
+
31
+ from utils import (
32
+ cleantxt_wrap,
33
+ corr,
34
+ )
35
+
36
+ nltk.download("stopwords") # TODO: find where this requirement originates from
37
+
38
+ sys.path.append(dirname(dirname(os.path.abspath(__file__))))
39
+ warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
40
+ import transformers
41
+
42
+ transformers.logging.set_verbosity_error()
43
+ logging.basicConfig()
44
+ cwd = Path.cwd()
45
+ my_cwd = str(cwd.resolve()) # string so it can be passed to os.path() objects
46
+
47
+
48
def chat(trivia_query):
    """
    chat - helper function that makes the whole gradio thing work.

    Args:
        trivia_query (str): the question to ask the bot

    Returns:
        [str]: an HTML snippet containing the question followed by the bot's
            response, each wrapped in <b>...</b> and followed by a <br>
    """
    # imported locally so the module name does not collide with other names in this file
    from html import escape

    response = ask_gpt(message=trivia_query, chat_pipe=my_chatbot)
    # both the user's text and the model's text are untrusted and get rendered
    # as raw HTML by the interface, so escape them before interpolating
    return "".join(
        f"<b>{escape(str(item))}</b> <br>" for item in (trivia_query, response)
    )
68
+
69
+
70
def ask_gpt(
    message: str,
    chat_pipe,
    speaker="person alpha",
    responder="person beta",
    max_len=196,
    top_p=0.95,
    top_k=50,
    temperature=0.6,
):
    """

    ask_gpt - a function that takes in a prompt and generates a response using the pipeline. This interacts the discussion function.

    Parameters:
        message (str): the question to ask the bot
        chat_pipe: the text-generation pipeline object, forwarded to discussion() as `pipeline`
        speaker (str): the name of the speaker (default: "person alpha")
        responder (str): the name of the responder (default: "person beta")
        max_len (int): the maximum length of the response, forwarded as `max_length` (default: 196)
        top_p (float): the top probability threshold (default: 0.95)
        top_k (int): the top k threshold (default: 50)
        temperature (float): the temperature of the response (default: 0.6)

    Returns:
        str: the post-processed bot response

    Note:
        reads the module-level names `basic_sc`, `schnellspell` and `ns_checker`,
        which are only assigned in the `__main__` block of this file.
    """

    st = time.perf_counter()
    prompt = clean(message).strip()  # clean user input, drop extra whitespace
    in_len = len(prompt)
    if in_len > 512:
        prompt = prompt[-512:]  # truncate to 512 chars
        print(f"Truncated prompt to last 512 chars: started with {in_len} chars")
        max_len = min(max_len, 512)

    resp = discussion(
        prompt_text=prompt,
        pipeline=chat_pipe,
        speaker=speaker,
        responder=responder,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        max_length=max_len,
    )
    gpt_et = time.perf_counter()
    gpt_rt = round(gpt_et - st, 2)
    rawtxt = resp["out_text"]
    # decide once whether to spell-correct; an empty response is passed through
    # untouched because symspeller() asserts on empty input
    if not rawtxt.strip() or detect_propers(rawtxt):
        # no correction needed
        cln_resp = rawtxt.strip()
    elif basic_sc:
        cln_resp = symspeller(rawtxt, sym_checker=schnellspell)
    else:
        cln_resp = neuspell_correct(rawtxt, checker=ns_checker)
    bot_resp_a = corr(remove_repeated_words(cln_resp))
    bot_resp = fix_punct_spacing(bot_resp_a)
    print(f"the prompt was:\n\t{message}\nand the response was:\n\t{bot_resp}\n")
    corr_rt = round(time.perf_counter() - gpt_et, 4)
    print(
        f"took {gpt_rt + corr_rt} sec to respond, {gpt_rt} for GPT, {corr_rt} for correction\n"
    )
    return remove_trailing_punctuation(bot_resp)
133
+
134
+
135
def get_parser():
    """
    get_parser - build the command-line argument parser for the app

    Returns:
        argparse.ArgumentParser: parser exposing -m/--model, --basic-sc and --verbose
    """
    cli = argparse.ArgumentParser(description="submit a question, GPT model responds")

    model_help = "the model to use for the chatbot on https://huggingface.co/models OR a path to a local model"
    cli.add_argument(
        "-m",
        "--model",
        type=str,
        required=False,
        default="pszemraj/GPT-Converse-1pt3B-Neo-WoW-DD-17",  # default model
        help=model_help,
    )

    # TODO: change the default back to False once Neuspell issues are resolved.
    spell_help = "turn on symspell (baseline) correction instead of the more advanced neural net models"
    cli.add_argument(
        "--basic-sc",
        action="store_true",
        required=False,
        default=True,
        help=spell_help,
    )

    cli.add_argument(
        "--verbose",
        default=False,
        action="store_true",
        help="turn on verbose logging",
    )
    return cli
165
+
166
+
167
if __name__ == "__main__":
    # --- parse CLI options ---
    args = get_parser().parse_args()
    default_model = str(args.model)
    model_loc = Path(default_model)  # if the model is a path, use it
    basic_sc = args.basic_sc  # whether to use the baseline spellchecker

    # --- pick the device: first GPU when CUDA is available, otherwise CPU ---
    cuda_ok = torch.cuda.is_available()
    device = 0 if cuda_ok else -1
    print(f"CUDA avail is {cuda_ok}")

    # --- load the text-generation pipeline from a local directory or by model name ---
    if model_loc.exists() and model_loc.is_dir():
        my_chatbot = pipeline(
            "text-generation", model=model_loc.resolve(), device=device
        )
    else:
        # if the model is a name, use it. stays on CPU if no GPU available
        my_chatbot = pipeline("text-generation", model=default_model, device=device)
    print(f"using model {my_chatbot.model}")

    # --- load whichever spell checker ask_gpt() will read from module scope ---
    if basic_sc:
        print("Using the baseline spellchecker")
        schnellspell = build_symspell_obj()
    else:
        print("using Neuspell spell checker")
        ns_checker = load_ns_checker(fast=False)

    print(f"using model stored here: \n {model_loc} \n")

    # --- build the gradio interface ---
    iface = gr.Interface(
        chat,
        inputs=["text"],
        outputs="html",
        examples_per_page=10,
        examples=[
            "How can you help me?",
            "what can you do?",
            "Hi, my name is……",
            "Happy birthday!",
            "I have a question, can you help me?",
            "Do you know a joke?",
            "Will you marry me?",
            "Are you single?",
            "Do you like people?",
            "Are you part of the Matrix?",
            "Do you have a hobby?",
            "You’re clever",
            "Tell me about your personality",
            "You’re annoying",
            "you suck",
            "I want to speak to a human now.",
            "Don’t you speak English?!",
            "Are you human?",
            "Are you a robot?",
            "What is your name?",
            "How old are you?",
            "What’s your age?",
            "What day is it today?",
            "Who made you?",
            "Which languages can you speak?",
            "What is your mother’s name?",
            "Where do you live?",
            "What’s the weather like today?",
            "Are you expensive?",
            "Do you get smarter?",
            "rate your overall satisfaction with the chatbot",
            "How many icebergs are in the ocean?",
        ],
        title=f"NLP template space: {default_model} Model",
        description="this space is used as a template. please copy the files in the space to your own space repo, AND THEN edit them ",
        article="here you can add more details about your model. \n\n"
        "**Important Notes & About:**\n\n"
        "1. the model can take up to 60 seconds to respond sometimes, patience is a virtue.\n"
        "2. the model started from a pretrained checkpoint, and was trained on several different datasets. Anything it says should be fact-checked before being regarded as a true statement.\n"
        "3. Some params are still being tweaked (in the future, will be inputs) any feedback is welcome :)\n",
        css="""
        .chatbox {display:flex;flex-direction:column}
        .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
        .user_msg {background-color:cornflowerblue;color:white;align-self:start}
        .resp_msg {background-color:lightgray;align-self:self-end}
    """,
        allow_screenshot=True,
        allow_flagging="never",
        theme="dark",
    )

    # launch the gradio interface and start the server
    iface.launch(
        # prevent_thread_lock=True,
        enable_queue=True,  # also allows for dealing with multiple users simultaneously (per newer gradio version)
    )
converse.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ converse.py - this script has functions for handling the conversation between the user and the bot.
3
+
4
+ https://huggingface.co/docs/transformers/v4.15.0/en/main_classes/model#transformers.generation_utils.GenerationMixin.generate.no_repeat_ngram_size
5
+ """
6
+
7
+
8
+ import pprint as pp
9
+ import time
10
+ import torch
11
+ import transformers
12
+
13
+ from grammar_improve import remove_trailing_punctuation
14
+
15
+
16
def discussion(
    prompt_text: str,
    speaker: str,
    responder: str,
    pipeline,
    timeout=45,
    max_length=128,
    top_p=0.95,
    top_k=50,
    temperature=0.7,
    full_text=False,
    num_return_sequences=1,
    device=-1,
    verbose=False,
):
    """
    discussion - build a two-speaker prompt from the user's text, generate the bot's reply via gen_response, and return it with the conversation so far.

    Parameters
    ----------
    prompt_text : str, the prompt to ask the bot, usually the user's question
    speaker : str, the name of the person who is speaking the prompt
    responder : str, the name of the person who is responding to the prompt
    pipeline : the pipeline object, forwarded to gen_response
    timeout : int, optional, forwarded to gen_response, by default 45
    max_length : int, optional, forwarded to gen_response, defaults to 128
    top_p : float, optional, forwarded to gen_response, defaults to 0.95
    top_k : int, optional, forwarded to gen_response, defaults to 50
    temperature : float, optional, forwarded to gen_response, defaults to 0.7
    full_text : bool, optional, forwarded to gen_response, defaults to False
    num_return_sequences : int, optional, forwarded to gen_response, defaults to 1
    device : int, optional, forwarded to gen_response, defaults to -1
    verbose : bool, optional, whether to print the prompt and the response, defaults to False

    Returns
    -------
    dict, with "out_text" (the bot response string) and "full_conv" (list of the prompt pieces plus the response)
    """

    # lower-cased "speaker:\n<prompt>\n\nresponder:\n" layout
    p_list = [
        speaker.lower() + ":" + "\n",
        prompt_text.lower() + "\n",
        "\n",
        responder.lower() + ":" + "\n",
    ]
    this_prompt = "".join(p_list)
    if verbose:
        print("overall prompt:\n")
        pp.pprint(this_prompt, indent=4)

    # call the model
    print("\n... generating...")
    bot_dialogue = gen_response(
        this_prompt,
        pipeline,
        speaker,
        responder,
        timeout=timeout,
        max_length=max_length,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        full_text=full_text,
        num_return_sequences=num_return_sequences,
        device=device,
        verbose=verbose,
    )

    # gen_response may hand back a single string or a list of lines
    bot_resp = bot_dialogue
    if isinstance(bot_dialogue, list):
        n_lines = len(bot_dialogue)
        if n_lines > 1:
            bot_resp = ", ".join(bot_dialogue)
        elif n_lines == 1:
            bot_resp = bot_dialogue[0]
    if isinstance(bot_resp, list):
        bot_resp = " ".join(bot_resp)

    # remove the last ',' '.' chars
    bot_resp = remove_trailing_punctuation(bot_resp.strip())
    if verbose:
        print("\n... bot response:\n")
        pp.pprint(bot_resp)
    p_list.extend([bot_resp + "\n", "\n"])

    print("\nfinished!")
    # return the bot response and the full conversation
    return {"out_text": bot_resp, "full_conv": p_list}
101
+
102
+
103
def gen_response(
    query: str,
    pipeline,
    speaker: str,
    responder: str,
    timeout=45,
    max_length=128,
    top_p=0.95,
    top_k=50,
    temperature=0.7,
    full_text=False,
    num_return_sequences=1,
    device=-1,
    verbose=False,
    **kwargs,
):
    """
    gen_response - call the pipeline on a prompt and extract the responder's lines from the first generated sequence. This operates underneath the discussion function.

    Parameters
    ----------
    query : str, the prompt to ask the bot, usually the user's question
    pipeline : callable pipeline object used for generating the response
    speaker : str, the name of the person who is speaking the prompt
    responder : str, the name of the person who is responding to the prompt
    timeout : int, optional, passed to the pipeline as max_time, by default 45
    max_length : int, optional, passed to the pipeline, capped at 1024, defaults to 128
    top_p : float, optional, passed to the pipeline, defaults to 0.95
    top_k : int, optional, passed to the pipeline, defaults to 50
    temperature : float, optional, passed to the pipeline, defaults to 0.7
    full_text : bool, optional, passed to the pipeline as return_full_text, defaults to False
    num_return_sequences : int, optional, passed to the pipeline, defaults to 1
    device : int, optional, accepted but not used inside this function, defaults to -1
    verbose : bool, optional, whether to print timing and debug output, defaults to False
    **kwargs : extra keyword arguments passed to the pipeline call

    Returns
    -------
    whatever consolidate_texts returns for the first generated sequence (a list of lines, or a single string)

    """

    if max_length > 1024:
        print("max_length is too large, setting to 1024")
        max_length = 1024

    generation_args = dict(
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        num_return_sequences=num_return_sequences,
        max_time=timeout,
        return_full_text=full_text,
        no_repeat_ngram_size=3,
        length_penalty=0.3,
        repetition_penalty=3.4,
        clean_up_tokenization_spaces=True,
    )
    st = time.perf_counter()
    response = pipeline(query, **generation_args, **kwargs)  # the likely better beam-less method
    rt = round(time.perf_counter() - st, 2)
    if verbose:
        print(f"took {rt} sec to respond")
        print("\n[DEBUG] generated:\n")
        pp.pprint(response)  # for debugging

    # process the full result to get the ~bot response~ piece
    # TODO: adjust hardcoded value for index to dynamic (if n>1)
    first_text = str(response[0]["generated_text"])
    this_result = first_text.split("\n")

    bot_dialogue = consolidate_texts(
        name_resp=responder,
        model_resp=this_result,
        name_spk=speaker,
        verbose=verbose,
        print_debug=True,
    )
    if verbose:
        print(f"DEBUG: {bot_dialogue} was original response pre-SC")

    return bot_dialogue
187
+
188
+
189
def consolidate_texts(
    model_resp: list,
    name_resp: str = None,
    name_spk: str = None,
    verbose=False,
    print_debug=False,
):
    """
    consolidate_texts - given a list of generated lines, collect the responder's lines up to the point where another speaker starts talking

    Parameters:
        model_resp (list): the list of strings to consolidate (usually from the model)
        name_resp (str): the name of the person who is responding (default: "person beta")
        name_spk (str): the name of the person who is speaking (default: "person alpha")
        verbose (bool): whether to print the results
        print_debug (bool): whether to print the debug info during looping

    Returns:
        list of the collected lines; if model_resp has exactly one element, that element is returned as-is

    Raises:
        AssertionError: if model_resp is empty
    """
    assert len(model_resp) > 0, "model_resp is empty"
    if len(model_resp) == 1:
        return model_resp[0]
    resp_tag = ("person beta" if name_resp is None else name_resp).lower()
    spk_tag = ("person alpha" if name_spk is None else name_spk).lower()
    if verbose:
        print("====" * 10)
        print(f"\n[DEBUG] initial model_resp has {len(model_resp)} lines: \n\t{model_resp}")
        print(f" the first element is \n\t{model_resp[0]} and it is {type(model_resp[0])}")
    fn_resp = []

    for resline in model_resp:
        lowered = resline.lower()
        # compare case-insensitively, like the speaker check below, so a
        # capitalised "Person Beta:" header is skipped instead of being kept
        if resp_tag in lowered:
            continue  # header line naming the bot: don't add this line to the list
        # stop at the speaker's name, or at any line that looks like another
        # speaker's "name: ..." header
        speaker_line = spk_tag in lowered
        other_header = ": " in resline or ":\n" in resline
        if speaker_line or other_header:
            if print_debug:
                print(f"\nDEBUG: \n\t{resline}\ncaused the break")
            break
        fn_resp.append(resline)
    if verbose:
        print("--" * 10)
        print("\nthe full response is:\n")
        print("\n".join(fn_resp))
        print("--" * 10)

    return fn_resp
grammar_improve.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ grammar_improve.py - this .py script contains functions to improve the grammar of a user's input or the models output.
3
+
4
+ """
5
+
6
+ from datetime import datetime
7
+ import os
8
+ import pprint as pp
9
+ from neuspell import BertChecker, SclstmChecker
10
+ import neuspell
11
+ import math
12
+ from cleantext import clean
13
+ import time
14
+ import re
15
+ import sys
16
+ from symspellpy.symspellpy import SymSpell
17
+
18
+ from utils import suppress_stdout
19
+
20
+
21
def detect_propers(text: str):
    """
    detect_propers - report whether a string contains a token matching the word pattern below

    NOTE(review): the pattern matches any word (optionally with an apostrophe or
    hyphens), not only capitalised / proper nouns, so this is True for any text
    containing at least one word character - confirm this is intended.

    Args:
        text (str): [string to be checked]

    Returns:
        [bool]: [True if the pattern is found anywhere in the string]
    """
    word_pattern = r"(?:\w+['’])?\w+(?:-(?:\w+['’])?\w+)*"
    return re.search(word_pattern, text) is not None
33
+
34
+
35
def fix_punct_spaces(string):
    """
    fix_punct_spaces - collapse whitespace around runs of ? ! . , so each run is followed by exactly one space. For example, "hello , there" -> "hello, there"

    Parameters
    ----------
    string : str, required, input string to be corrected

    Returns
    -------
    str, corrected string with leading/trailing whitespace stripped
    """

    def _tighten(match):
        # drop spaces inside the punctuation run, then add one trailing space
        return match.group(1).replace(" ", "") + " "

    respaced = re.sub(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*", _tighten, string)
    return respaced.strip()
51
+
52
+
53
def split_sentences(text: str):
    """
    split_sentences - split a string into sentences at whitespace that follows a "." or "?", keeping the ending punctuation on each sentence

    Args:
        text (str): [string to be split]

    Returns:
        [list]: [list of strings]
    """
    # the two negative lookbehinds skip splits after patterns like "e.g." and "Mr."
    boundary = re.compile(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s")
    return boundary.split(text)
64
+
65
+
66
def remove_repeated_words(bot_response):
    """
    remove_repeated_words - remove repeated words from a string, returning only the first instance of each word

    Words are the whitespace-separated tokens of the input and are compared
    exactly (case- and punctuation-sensitive).

    Parameters
    ----------
    bot_response : str
        string to remove repeated words from

    Returns
    -------
    str
        string containing the first instance of each word, joined by single spaces
    """
    # dict keys keep insertion order, giving an order-preserving dedupe in one
    # pass instead of a list-membership scan per word
    return " ".join(dict.fromkeys(bot_response.split()))
86
+
87
+
88
def remove_trailing_punctuation(text: str, fuLL_strip=False):
    """
    remove_trailing_punctuation - remove trailing punctuation from a string. Purpose is to seem more natural to end users

    Args:
        text (str): [string to be cleaned]
        fuLL_strip (bool, optional): also remove trailing "?" and "!" in addition to ".,;:". Defaults to False.

    Returns:
        [str]: [string with the trailing punctuation characters removed]
    """
    # rstrip, not strip: only the END of the string is cleaned, leading
    # punctuation is left alone
    trailing_chars = "?!.,;:" if fuLL_strip else ".,;:"
    return text.rstrip(trailing_chars)
102
+
103
+
104
def fix_punct_spacing(text: str):
    """
    fix_punct_spacing - normalise spacing around runs of ? ! . , and then collapse immediately repeated non-word characters to one

    Args:
        text (str): [string to be corrected]

    Returns:
        [str]: [corrected string; not stripped, so a trailing space can remain]
    """

    def _tighten(match):
        # drop spaces inside the punctuation run, then add one trailing space
        return match.group(1).replace(" ", "") + " "

    spaced = re.sub(r"\s*([?!.,]+(?:\s+[?!.,]+)*)\s*", _tighten, text)
    # delete any non-word character that is directly followed by a copy of itself
    return re.sub(r"(\W)(?=\1)", "", spaced)
110
+
111
+
112
+ """
113
+ start of SymSpell code
114
+ """
115
+
116
+
117
def symspeller(
    my_string: str,
    sym_checker=None,
    max_dist: int = 2,
    prefix_length: int = 7,
    ignore_non_words=True,
    dictionary_path: str = None,
    bigram_path: str = None,
    verbose=False,
):
    """
    symspeller - a wrapper for the SymSpell class from symspellpy

    Parameters
    ----------
    my_string : str, required, the string to be checked
    sym_checker : SymSpell, optional, default=None, the SymSpell object to use; built via build_symspell_obj when None
    max_dist : int, optional, default=2, the maximum edit distance passed to lookup_compound
    prefix_length : int, optional, default=7, the prefix length used when a new checker is built
    ignore_non_words : bool, optional, default=True, whether to ignore non-words
    dictionary_path : str, optional, default=None, the path to the dictionary file (only used when a new checker is built)
    bigram_path : str, optional, default=None, the path to the bigram dictionary file (only used when a new checker is built)
    verbose : bool, optional, default=False, whether to print the results

    Returns
    -------
    str, the term of the first suggestion, or the cleaned input when there are no suggestions

    Raises
    ------
    AssertionError, if my_string is empty
    """

    assert len(my_string) > 0, "entered string for correction is empty"

    if sym_checker is None:
        # need to create a new class object. user can specify their own dictionary and bigram files
        if verbose:
            print("creating new SymSpell object")
        sym_checker = build_symspell_obj(
            edit_dist=max_dist,
            prefix_length=prefix_length,
            dictionary_path=dictionary_path,
            bigram_path=bigram_path,
        )
    elif verbose:
        print("using existing SymSpell object")

    # max edit distance per lookup (per single word, not per whole input string)
    suggestions = sym_checker.lookup_compound(
        my_string,
        max_edit_distance=max_dist,
        ignore_non_words=ignore_non_words,
        ignore_term_with_digits=True,
        transfer_casing=True,
    )

    if verbose:
        print(f"{len(suggestions)} suggestions found")
        print(f"the original string is:\n\t{my_string}")
        sug_list = [sug.term for sug in suggestions]
        print(f"suggestions:\n\t{sug_list}\n")

    if len(suggestions) < 1:
        return clean(my_string)  # no correction because no suggestions
    # first result is the most likely; read the public `term` attribute, as the
    # verbose branch above already does, rather than the private `_term`
    return suggestions[0].term
182
+
183
+
184
def build_symspell_obj(
    edit_dist=2,
    prefix_length=7,
    dictionary_path=None,
    bigram_path=None,
):
    """
    build_symspell_obj [build a SymSpell object and load its dictionaries]

    Args:
        edit_dist (int, optional): the checker is created with max_dictionary_edit_distance = edit_dist + 2. Defaults to 2.
        prefix_length (int, optional): prefix length passed to SymSpell. Defaults to 7.
        dictionary_path (str, optional): path to the frequency dictionary; the bundled symspell_rsc file is used when None.
        bigram_path (str, optional): path to the bigram dictionary; the bundled symspell_rsc file is used when None.

    Returns:
        SymSpell: a SymSpell object
    """
    # fall back to the files under symspell_rsc/ (paths are relative)
    if dictionary_path is None:
        dictionary_path = r"symspell_rsc/frequency_dictionary_en_82_765.txt"
    if bigram_path is None:
        bigram_path = r"symspell_rsc/frequency_bigramdictionary_en_243_342.txt"

    checker = SymSpell(
        max_dictionary_edit_distance=edit_dist + 2, prefix_length=prefix_length
    )
    # term_index is the column of the term and count_index is the
    # column of the term frequency
    checker.load_dictionary(dictionary_path, term_index=0, count_index=1)
    checker.load_bigram_dictionary(bigram_path, term_index=0, count_index=2)
    return checker
218
+
219
+
220
+ """
221
+ # if using t5b_correction to check for spelling errors, use this code to initialize the objects
222
+
223
+ import torch
224
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
225
+
226
+ model_name = 'deep-learning-analytics/GrammarCorrector'
227
+ # torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
228
+ torch_device = 'cpu'
229
+ gc_tokenizer = T5Tokenizer.from_pretrained(model_name)
230
+ gc_model = T5ForConditionalGeneration.from_pretrained(model_name).to(torch_device)
231
+
232
+ """
233
+
234
+
235
def t5b_correction(prompt: str, korrektor, verbose=False, beams=4):
    """
    t5b_correction - correct a string by calling a text2textgen pipeline object on it with a "grammar: " prefix

    Parameters
    ----------
    prompt : str, required, input prompt to be corrected
    korrektor : callable pipeline object, required
    verbose : bool, optional, whether to print the length settings and the result. Defaults to False.
    beams : int, optional, number of beams to use for the correction. Defaults to 4.

    Returns
    -------
    the object returned by korrektor, unmodified
    """

    n_chars = len(prompt)
    p_max_len = int(math.ceil(1.1 * n_chars))
    if verbose:
        # the min value is reported only; it is not passed to the pipeline
        p_min_len = int(math.ceil(0.9 * n_chars))
        print(f"setting min to {p_min_len} and max to {p_max_len}\n")

    call_args = dict(
        return_text=True,
        clean_up_tokenization_spaces=True,
        num_beams=beams,
        max_length=p_max_len,
        repetition_penalty=1.3,
        length_penalty=0.2,
        no_repeat_ngram_size=2,
    )
    gcorr_result = korrektor(f"grammar: {prompt}", **call_args)
    if verbose:
        print(f"grammar correction result: \n\t{gcorr_result}\n")
    return gcorr_result
268
+
269
+
270
def all_neuspell_chkrs():
    """
    all_neuspell_chkrs - print and return the names exposed by the neuspell module

    Parameters
    ----------
    None

    Returns
    -------
    checker_opts - list, the result of dir(neuspell)
    """

    print("\navailable checkers:")
    checker_opts = dir(neuspell)
    pp.pprint(checker_opts, indent=4, compact=True)
    return checker_opts
289
+
290
+
291
def load_ns_checker(customckr=None, fast=False):
    """
    load_ns_checker - helper function, load / "set up" a pretrained neuspell checker

    Args:
        customckr (class, optional): checker class, instantiated as customckr(pretrained=True). If not provided, a default checker is loaded.
        fast (bool, optional): when no custom class is given, load SclstmChecker instead of BertChecker. Defaults to False.

    Returns:
        the instantiated checker object
    """
    started = time.perf_counter()
    # stop all printing to the console while the checker loads
    with suppress_stdout():
        if customckr is not None:
            checker = customckr(pretrained=True)
        elif fast:
            # this one is faster but not as accurate
            checker = SclstmChecker(pretrained=True)
        else:
            # load the default checker, has the best balance
            checker = BertChecker(pretrained=True)
    rt_min = (time.perf_counter() - started) / 60
    print(f"\n\nloaded checker in {rt_min} minutes")

    return checker
320
+
321
+
322
def neuspell_correct(input_text: str, checker=None, verbose=False):
    """
    neuspell_correct - correct a string using neuspell.
    note that modifications to the checker are needed if doing list-based corrections

    Parameters
    ----------
    input_text : str, required, input string to be corrected
    checker : neuspell checker object, optional. Defaults to None, in which case a pretrained SclstmChecker is loaded.
    verbose : bool, optional, whether to print the corrected string. Defaults to False.

    Returns
    -------
    str, corrected string (strings shorter than 4 characters are returned unchanged)
    """
    too_short = isinstance(input_text, str) and len(input_text) < 4
    if too_short:
        print(f"input text of {input_text} is too short to be corrected")
        return input_text

    if checker is None:
        print("NOTE - no checker provided, loading default checker")
        checker = SclstmChecker(pretrained=True)

    # run the checker, then tidy the spacing around punctuation
    cleaned_txt = fix_punct_spaces(checker.correct(input_text))

    if verbose:
        print(f"neuspell correction result: \n\t{cleaned_txt}\n")
    return cleaned_txt
351
+
352
+
353
def grammarpipe(corrector, qphrase: str):
    """
    grammarpipe - THE ORIGINAL ONE USED IN PROJECT AND NEEDS TO BE CHANGED.
    Idea is to correct a string using a text2textgen pipeline model from transformers
    Args:
        corrector (transformers.pipeline): [transformers pipeline object, already created w/ relevant model]
        qphrase (str): [text to be corrected]
    Returns:
        [str]: [corrected text; the cleaned input if the corrector fails; the input unchanged if shorter than 4 characters]
    """
    too_short = isinstance(qphrase, str) and len(qphrase) < 4
    if too_short:
        print(f"input text of {qphrase} is too short to be corrected")
        return qphrase

    try:
        cleaned_input = clean(qphrase)
        corrected = corrector(
            cleaned_input, return_text=True, clean_up_tokenization_spaces=True
        )
        return corrected[0]["generated_text"]
    except Exception as e:
        # best effort: fall back to the cleaned input when correction fails
        print(f"NOTE - failed to correct with grammarpipe:\n {e}")
        return clean(qphrase)
374
+
375
+
376
def DLA_correct(qphrase: str, tokenizer=None, model=None, **kwargs):
    """
    DLA_correct - an "overhead" function to call correct_grammar() on a string, allowing for each sentence to be corrected individually

    Args:
        qphrase (str): [string to be corrected]
        tokenizer (optional): tokenizer object forwarded to correct_grammar(); required for any input of 4+ characters
        model (optional): model object forwarded to correct_grammar(); required for any input of 4+ characters
        **kwargs: extra keyword arguments forwarded to correct_grammar()

    Returns:
        str, the corrected sentence, or the corrected sentences joined under " "

    Raises:
        TypeError: if tokenizer or model is missing for an input that needs correction
    """
    if isinstance(qphrase, str) and len(qphrase) < 4:
        print(f"input text of {qphrase} is too short to be corrected")
        return qphrase

    # correct_grammar() cannot run without these two; fail with a clear message
    # instead of a missing-positional-argument error from inside the call
    if tokenizer is None or model is None:
        raise TypeError("DLA_correct requires both a tokenizer and a model")

    sentences = split_sentences(qphrase)
    if len(sentences) == 1:
        # a single sentence is passed through as-is (not cleaned first)
        return correct_grammar(sentences[0], tokenizer, model, **kwargs)
    return " ".join(
        correct_grammar(clean(sen), tokenizer, model, **kwargs) for sen in sentences
    )
400
+
401
+
402
def correct_grammar(
    input_text: str,
    tokenizer,
    model,
    n_results: int = 1,
    beams: int = 8,
    temp=1,
    uniq_ngrams=2,
    rep_penalty=1.5,
    device="cpu",
):
    """
    correct_grammar - correct a string using a tokenizer + seq2seq model from transformers.
    This function is an alternative to the t5b_correction function.

    Parameters
    ----------
    input_text : str, required, input string to be corrected
    tokenizer : transformers.T5Tokenizer, required, tokenizer object, already created w/ relevant model
    model : transformers.T5ForConditionalGeneration, required, model object, already created w/ relevant model
    n_results : int, optional, number of results to return. Defaults to 1.
    beams : int, optional, number of beams to use for the correction. Defaults to 8.
    temp : int, optional, temperature to use for the correction. Defaults to 1.
    uniq_ngrams : int, optional, passed as no_repeat_ngram_size. Defaults to 2.
    rep_penalty : float, optional, passed as repetition_penalty. Defaults to 1.5.
    device : str, optional, device the tokenized batch is moved to. Defaults to 'cpu'.

    Returns
    -------
    str, corrected string (or list of strings if n_results > 1); inputs shorter
    than 5 characters are returned unchanged
    """
    st = time.perf_counter()

    if len(input_text) < 5:
        return input_text
    # length budget derived from the character count, capped at 128
    max_length = min(int(math.ceil(len(input_text) * 1.2)), 128)
    batch = tokenizer(
        [input_text],
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    ).to(device)
    translated = model.generate(
        **batch,
        max_length=max_length,
        min_length=min(10, len(input_text)),
        no_repeat_ngram_size=uniq_ngrams,
        repetition_penalty=rep_penalty,
        num_beams=beams,
        num_return_sequences=n_results,
        temperature=temp,
    )

    tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
    rt_min = (time.perf_counter() - st) / 60
    print(f"\n\ncorrected in {rt_min} minutes")

    if isinstance(tgt_text, list):
        # a single result is unwrapped; several results are returned together
        return tgt_text[0] if n_results == 1 else tgt_text
    return tgt_text
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers>=4.12.5
2
+ sentencepiece>=0.1.96
3
+ tqdm>=4.43.0
4
+ symspellpy>=6.7.0
5
+ requests>=2.24.0
6
+ gradio>=2.4.6
7
+ natsort>=7.1.1
8
+ pandas>=1.3.0
9
+ aitextgen>=0.5.2
10
+ clean-text>=0.5.0
11
+ openwa>=1.3.16
12
+ python-telegram-bot>=13.0
13
+ webwhatsapi>=2.0.5
14
+ Flask>=2.0.2
15
+ nltk>=3.6.6
16
+ neuspell>=1.0.0
symspell_rsc/frequency_bigramdictionary_en_243_342.txt ADDED
The diff for this file is too large to render. See raw diff
 
symspell_rsc/frequency_dictionary_en_82_765.txt ADDED
The diff for this file is too large to render. See raw diff
 
utils.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ utils - general utility functions for loading, saving, and manipulating data
3
+ """
4
+
5
+ import os
6
+ from pathlib import Path
7
+ import pprint as pp
8
+ import re
9
+ import shutil # zipfile formats
10
+ from datetime import datetime
11
+ from os.path import basename
12
+ from os.path import getsize, join
13
+
14
+ import requests
15
+ from cleantext import clean
16
+ from natsort import natsorted
17
+ from symspellpy import SymSpell
18
+ import pandas as pd
19
+ from tqdm.auto import tqdm
20
+
21
+
22
+ from contextlib import contextmanager
23
+ import sys
24
+ import os
25
+
26
+
27
@contextmanager
def suppress_stdout():
    """
    suppress_stdout - silence everything written to sys.stdout inside the `with` block.

    sys.stdout is pointed at os.devnull while the block runs and is put back
    afterwards, even when the block raises.
    credit to https://newbedev.com/how-to-suppress-console-output-in-python
    """
    saved_stream = sys.stdout
    with open(os.devnull, "w") as null_sink:
        sys.stdout = null_sink
        try:
            yield
        finally:
            # always restore the stream that was active on entry
            sys.stdout = saved_stream
39
+
40
+
41
def remove_string_extras(mytext):
    """remove_string_extras - drop every character except A-Za-z0-9, space, and . , ;"""
    disallowed = re.compile(r"[^A-Za-z0-9 .,;]+")
    return disallowed.sub("", mytext)
44
+
45
+
46
def corr(s):
    """
    corr - normalise spacing in a string.

    Runs of spaces are collapsed to a single space first, then a space is inserted
    after every period that is not already followed by one. This applies to *every*
    such period, including one at the very end of the string and periods inside
    numbers (e.g. "3.14" becomes "3. 14").
    """
    single_spaced = re.sub(r" +", " ", s)
    return re.sub(r"\.(?! )", ". ", single_spaced)
50
+
51
+
52
def get_timestamp():
    """get_timestamp - current time (datetime.now()) formatted for file names as '<Mon>-<DD>-<YYYY>_t-<HH>'"""
    stamp_format = "%b-%d-%Y_t-%H"
    right_now = datetime.now()
    return right_now.strftime(stamp_format)
55
+
56
+
57
def print_spacer(n=1):
    """print_spacer - print the spacer line `n` times in a single print call"""
    spacer = "\n -------- "
    print(spacer * n)
60
+
61
+
62
def fast_scandir(dirname: str):
    """
    fast_scandir [recursively collect the paths of all subfolders below `dirname` using os.scandir]

    The direct children come first, followed by the descendants of each child in turn.
    """
    found = []
    for entry in os.scandir(dirname):
        if entry.is_dir():
            found.append(entry.path)

    # walk over a snapshot so the list can grow while we recurse
    for child in tuple(found):
        found += fast_scandir(child)
    return found  # list
72
+
73
+
74
def create_folder(directory: str):
    """create_folder - create `directory` via os.makedirs; nothing happens if it already exists"""
    os.makedirs(name=directory, exist_ok=True)
77
+
78
+
79
def chunks(lst: list, n: int):
    """
    chunks - lazily yield consecutive slices of `lst`, each holding at most `n` items
    Args:   lst (list): list to be chunked
            n (int): size of chunks

    The final chunk is shorter when len(lst) is not a multiple of n.
    """
    start_positions = range(0, len(lst), n)
    yield from (lst[start : start + n] for start in start_positions)
89
+
90
+
91
def chunky_pandas(my_df, num_chunks: int = 4):
    """
    chunky_pandas [split dataframe into chunks of len(my_df) // num_chunks rows, return each inside a list]

    When len(my_df) is not a multiple of `num_chunks` the leftover rows form one
    extra, shorter chunk, so the list can hold more than `num_chunks` items.

    Args:
        my_df (pd.DataFrame)
        num_chunks (int, optional): target number of chunks, expected to be positive. Defaults to 4.

    Returns:
        list: a list of dataframes (empty list for an empty dataframe)
    """
    n = len(my_df) // num_chunks
    if n == 0:
        # fewer rows than requested chunks: a step of 0 would make range() raise,
        # so fall back to one row per chunk
        n = 1
    return [my_df[i : i + n] for i in range(0, my_df.shape[0], n)]
106
+
107
+
108
def load_dir_files(
    directory: str, req_extension=".txt", return_type="list", verbose=False
):
    """
    load_dir_files - walk `directory` and its subdirectories with os.walk and collect every file whose name ends with `req_extension`

    Args:
        directory (str): root folder to search
        req_extension (str, optional): required file-name ending. Defaults to ".txt".
        return_type (str, optional): "list" (case-insensitive) returns a list; any other value returns a dict. Defaults to "list".
        verbose (bool, optional): print a preview of up to ten matches. Defaults to False.

    Returns:
        list or dict: naturally sorted file paths, or a dict mapping each file's basename to its path
        (when two files share a basename the later path in sort order wins)
    """
    matches = natsorted(
        [
            os.path.join(root, fname)
            for root, _subdirs, fnames in os.walk(directory)
            for fname in fnames
            if fname.endswith(req_extension)
        ]
    )

    if verbose:
        print("A list of files in the {} directory are: \n".format(directory))
        many = len(matches) >= 10
        pp.pprint(matches[:10] if many else matches)
        if many:
            print("\n and more. There are a total of {} files".format(len(matches)))

    if return_type.lower() == "list":
        return matches

    if verbose:
        print("returning dictionary")
    return {basename(path): path for path in matches}
149
+
150
+
151
def URL_string_filter(text):
    """
    URL_string_filter - keep only ASCII letters, digits, "." and "_" from `text`; every other character is dropped

    """
    allowed_chars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ._"
    return "".join(ch for ch in text if ch in allowed_chars)
163
+
164
+
165
def getFilename_fromCd(cd):
    """
    getFilename_fromCd - get the filename from a given cd (content-disposition) str

    Returns everything after "filename=" when that marker is present. Otherwise, if
    the string contains a "/", the text after the last "/" is passed through
    URL_string_filter and returned. Returns None for an empty/None input or when
    neither form is found.
    """
    if not cd:
        return None

    matches = re.findall(r"filename=(.+)", cd)
    if matches:
        return matches[0]

    # membership test rather than str.find(): find() gives -1 (truthy) when there
    # is no "/" and 0 (falsy) when the string starts with one
    if "/" in cd:
        return URL_string_filter(cd.rsplit("/", 1)[1])

    return None
178
+
179
+
180
def get_zip_URL(
    URLtoget: str,
    extract_loc: str = None,
    file_header: str = "dropboxexport_",
    verbose: bool = False,
):
    """
    get_zip_URL - download a zip file from a given URL and extract it to a given location

    Args:
        URLtoget (str): URL of the archive to download
        extract_loc (str, optional): folder to extract into, joined onto the current working directory. Defaults to None, which uses "dropbox_dl".
        file_header (str, optional): prefix for the name of the temporary archive file. Defaults to "dropboxexport_".
        verbose (bool, optional): print the download size and a listing of the extracted files. Defaults to False.

    Returns:
        str: path of the folder the archive was extracted into
    """

    r = requests.get(URLtoget, allow_redirects=True)
    names = getFilename_fromCd(r.headers.get("content-disposition"))
    # without a usable content-disposition header fall back to a generic name that
    # still carries a .zip extension for unpack_archive to recognise
    fixed_fnames = names.split(";") if names else ["archive.zip"]  # split the multiple results
    this_filename = file_header + URL_string_filter(fixed_fnames[0])

    # define paths and save the zip file
    if extract_loc is None:
        extract_loc = "dropbox_dl"
    dl_place = join(os.getcwd(), extract_loc)
    create_folder(dl_place)
    save_loc = join(os.getcwd(), this_filename)
    with open(save_loc, "wb") as archive_file:
        archive_file.write(r.content)
    if verbose:
        print("downloaded file size was {} MB".format(getsize(save_loc) / 1000000))

    # unpack the archive
    shutil.unpack_archive(save_loc, extract_dir=dl_place)
    if verbose:
        print("extracted zip file - ", datetime.now())
        # called only for the listing it prints when verbose
        load_dir_files(dl_place, req_extension="", verbose=verbose)

    # remove original
    try:
        os.remove(save_loc)
    except OSError:
        print("unable to delete original zipfile - check if exists", datetime.now())

    print("finished extracting zip - ", datetime.now())

    return dl_place
219
+
220
+
221
def merge_dataframes(data_dir: str, ext=".xlsx", verbose=False):
    """
    merge_dataframes - given a filepath, loads and attempts to merge all files as dataframes

    Every matching file is read with pd.read_excel regardless of `ext`; files that
    cannot be read are reported and skipped.

    Args:
        data_dir (str): [root directory to search in]
        ext (str, optional): [anticipate file extension for the dataframes ]. Defaults to '.xlsx'.
        verbose (bool, optional): [print the file listing and a summary of the merged dataframe]. Defaults to False.

    Returns:
        pd.DataFrame(): merged dataframe of all files (empty if nothing could be read)
    """

    src_str = str(Path(data_dir).resolve())
    all_reports = load_dir_files(directory=src_str, req_extension=ext, verbose=verbose)

    frames = []
    failed = []

    for df_path in tqdm(all_reports, total=len(all_reports), desc="joining data..."):
        try:
            frames.append(pd.read_excel(df_path).convert_dtypes())
        except Exception:
            # deliberately broad: one unreadable file must not abort the whole merge
            short_p = os.path.basename(df_path)
            print(
                f"WARNING - file with extension {ext} and name {short_p} could not be read."
            )
            failed.append(short_p)

    # concatenate once at the end instead of re-copying the growing frame per file
    mrg_df = pd.concat(frames, axis=0) if frames else pd.DataFrame()

    if failed:
        print(
            "failed to merge {} files, investigate as needed".format(len(failed))
        )

    if verbose:
        # DataFrame.info() prints its report itself and returns None
        mrg_df.info(verbose=True)

    return mrg_df
261
+
262
+
263
def download_URL(url: str, file=None, dlpath=None, verbose=False):
    """
    download_URL - download a file from a URL and show progress bar

    Parameters
    ----------
    url : str
        URL to download
    file : str, optional
        name to save the download under, by default None, in which case the name is
        taken from the last segment of the URL (with any "?dl=" suffix removed)
        and passed through cleantext's clean()
    dlpath : str or Path, optional
        directory to save into, by default None (the current working directory)
    verbose : bool, optional
        print the save location when finished, by default False

    Returns
    -------
    str - path to the downloaded file
    """

    if file is None:
        last_segment = url.split("/")[-1]
        if "?dl=" in url:
            # is a dropbox link
            filename = str(last_segment).split("?dl=")[0]
        else:
            filename = last_segment

        file = clean(filename)

    # save to current working directory unless told otherwise
    dlpath = Path.cwd() if dlpath is None else Path(dlpath)
    dl_loc = dlpath / file

    # the context manager releases the streamed connection even if writing fails
    with requests.get(url, stream=True, allow_redirects=True) as r:
        # servers may omit content-length; total=None gives a bar without a known end
        content_length = r.headers.get("content-length")
        total_size = int(content_length) if content_length is not None else None

        with open(str(dl_loc.resolve()), "wb") as f:
            with tqdm(
                total=total_size,
                unit="B",
                unit_scale=True,
                desc=file,
                initial=0,
                ascii=True,
            ) as pbar:
                for ch in r.iter_content(chunk_size=1024):
                    if ch:
                        f.write(ch)
                        pbar.update(len(ch))

    if verbose:
        print(f"\ndownloaded {file} to {dlpath}\n")

    return str(dl_loc.resolve())
319
+
320
+
321
def dl_extract_zip(
    URLtoget: str,
    extract_loc: str = None,
    file_header: str = "TEMP_archive_dl_",
    verbose: bool = False,
):
    """
    dl_extract_zip - generic function to download a zip file and extract it

    Parameters
    ----------
    URLtoget : str
        zip file URL to download
    extract_loc : str, optional
        directory to extract zip to, by default None, which uses a folder named
        "<file_header>extracted" in the current working directory
    file_header : str, optional
        name (without the ".zip" suffix) of the temporary archive file,
        by default "TEMP_archive_dl_"
    verbose : bool, optional
        print progress messages and a listing of the extracted files, by default False

    Returns
    -------
    pathlib.Path - path to the downloaded and extracted folder
    """

    if extract_loc is None:
        # Path(None) raises TypeError, so the documented default needs a real folder
        extract_loc = Path.cwd() / f"{file_header}extracted"
    extract_loc = Path(extract_loc)
    extract_loc.mkdir(parents=True, exist_ok=True)

    save_loc = download_URL(
        url=URLtoget, file=f"{file_header}.zip", dlpath=None, verbose=verbose
    )

    shutil.unpack_archive(save_loc, extract_dir=extract_loc)

    if verbose:
        print("extracted zip file - ", datetime.now())
        # called only for the listing it prints when verbose
        load_dir_files(extract_loc, req_extension="", verbose=verbose)

    # remove original
    try:
        os.remove(save_loc)
    except OSError:
        print("unable to delete original zipfile - check if exists", datetime.now())

    if verbose:
        print("finished extracting zip - ", datetime.now())

    return extract_loc
370
+
371
+
372
def cleantxt_wrap(ugly_text, all_lower=False):
    """
    cleantxt_wrap - applies the clean function to a non-empty string.

    Args:
        ugly_text (str): [string to be cleaned]
        all_lower (bool, optional): [forwarded to clean() as its `lower` argument]. Defaults to False.

    Returns:
        [str]: [cleaned string; empty strings and non-string inputs are returned untouched]
    """
    is_nonempty_str = isinstance(ugly_text, str) and len(ugly_text) > 0
    if not is_nonempty_str:
        return ugly_text
    return clean(ugly_text, lower=all_lower)