| """ | |
| app.py - the main file for the app. This creates the flask app and handles the routes. | |
| """ | |
import argparse
import logging
import os
import sys
import time
import warnings
from os.path import dirname
from pathlib import Path

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")

import gradio as gr
import nltk
import torch
from cleantext import clean
from gradio.inputs import Slider, Textbox, Radio
from transformers import pipeline
from converse import discussion
from grammar_improve import (
    build_symspell_obj,
    detect_propers,
    fix_punct_spacing,
    load_ns_checker,
    neuspell_correct,
    remove_repeated_words,
    remove_trailing_punctuation,
    symspeller,
    synthesize_grammar,
)
from utils import corr, setup_logging
nltk.download("stopwords")  # download stopwords
sys.path.append(dirname(dirname(os.path.abspath(__file__))))
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing.*")

import transformers

transformers.logging.set_verbosity_error()

cwd = Path.cwd()
_cwd_str = str(cwd.resolve())  # string version so it can be passed to os.path functions
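
# NOTE: chat() and ask_gpt() below rely on module-level globals (my_chatbot,
# basic_sc, basic_spell, grammarbot) that are created in the __main__ block at
# the bottom of this file, so they only work once the script is run directly.
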
def chat(
    prompt_message,
    temperature: float = 0.5,
    top_p: float = 0.95,
    top_k: int = 20,
    constrained_generation: str = "False",
) -> str:
| """ | |
| chat - the main function for the chatbot. This is the function that is called when the user | |
| :param _type_ prompt_message: the message to send to the model | |
| :param float temperature: the temperature value for the model, defaults to 0.6 | |
| :param float top_p: the top_p value for the model, defaults to 0.95 | |
| :param int top_k: the top_k value for the model, defaults to 25 | |
| :param bool constrained_generation: whether to use constrained generation or not, defaults to False | |
| :return str: the response from the model | |
| """ | |
    response = ask_gpt(
        message=prompt_message,
        chat_pipe=my_chatbot,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        constrained_generation=constrained_generation,
    )
    history = [prompt_message, response]
    # render the prompt and the reply as basic HTML
    html = ""
    for item in history:
        html += f"<b>{item}</b> <br><br>"
    return html
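
# Illustrative usage (a sketch, not executed here): once the globals are set up
# in the __main__ block, a call like
#   chat("Hello there!", temperature=0.6, top_p=0.95, top_k=20)
# returns both sides of the exchange as HTML, e.g.
#   "<b>Hello there!</b> <br><br><b>...model reply...</b> <br><br>"
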
def ask_gpt(
    message: str,
    chat_pipe,
    speaker="person alpha",
    responder="person beta",
    min_length=12,
    max_length=48,
    top_p=0.95,
    top_k=25,
    temperature=0.5,
    constrained_generation=False,
    max_input_length=128,
) -> str:
| """ | |
| ask_gpt - helper function that asks the GPT model a question and returns the response | |
| :param str message: the question to ask the model | |
| :param chat_pipe: the pipeline object for the model, created by the pipeline() function | |
| :param str speaker: the name of the speaker, defaults to "person alpha" | |
| :param str responder: the name of the responder, defaults to "person beta" | |
| :param int min_length: the minimum length of the response, defaults to 12 | |
| :param int max_length: the maximum length of the response, defaults to 64 | |
| :param float top_p: the top_p value for the model, defaults to 0.95 | |
| :param int top_k: the top_k value for the model, defaults to 25 | |
| :param float temperature: the temperature value for the model, defaults to 0.6 | |
| :param bool constrained_generation: whether to use constrained generation or not, defaults to False | |
| :return str: the response from the model | |
| """ | |
    st = time.perf_counter()
    prompt = clean(message)  # clean user input
    prompt = prompt.strip()  # get rid of any extra whitespace
    tokens = chat_pipe.tokenizer(prompt).input_ids
    in_len = len(tokens)
    if in_len > max_input_length:
        # keep only the last max_input_length tokens
        trunc_tokens = tokens[-max_input_length:]
        prompt = chat_pipe.tokenizer.decode(trunc_tokens)
        print(f"truncated prompt to {len(trunc_tokens)} tokens, input length: {in_len}")

    logging.info(f"prompt: {prompt}")
    resp = discussion(
        prompt_text=prompt,
        pipeline=chat_pipe,
        speaker=speaker,
        responder=responder,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        max_length=max_length,
        min_length=min_length,
        constrained_beam_search=constrained_generation,
    )
    gpt_et = time.perf_counter()
    gpt_rt = round(gpt_et - st, 2)
    rawtxt = resp["out_text"]
    # correct the raw output: symspell if basic_sc is set, otherwise the neural corrector
    if basic_sc:
        cln_resp = symspeller(rawtxt, sym_checker=basic_spell)
    else:
        cln_resp = synthesize_grammar(corrector=grammarbot, message=rawtxt)
    bot_resp_a = corr(remove_repeated_words(cln_resp))
    bot_resp = fix_punct_spacing(bot_resp_a)
    corr_rt = round(time.perf_counter() - gpt_et, 4)
    print(f"{gpt_rt + corr_rt}s to respond, {gpt_rt}s GPT, {corr_rt}s for correction\n")
    return remove_trailing_punctuation(bot_resp)
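
# Illustrative direct usage (a sketch; assumes a text-generation pipeline and
# the module-level correction objects are already loaded, as in __main__ below):
#   bot = pipeline("text-generation", model="ethzanalytics/ai-msgbot-gpt2-XL")
#   reply = ask_gpt("How are you today?", chat_pipe=bot, temperature=0.6)
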
def get_parser():
    """
    get_parser - build and return the command-line argument parser
    """
    parser = argparse.ArgumentParser(
        description="submit a question, GPT model responds"
    )
    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="ethzanalytics/ai-msgbot-gpt2-XL",  # default model
        help="the model to use for the chatbot on https://huggingface.co/models OR a path to a local model",
    )
    parser.add_argument(
        "-gm",
        "--gram-model",
        required=False,
        type=str,
        default="pszemraj/grammar-synthesis-base",
        help="text2text generation model ID from huggingface for the model to correct grammar",
    )
    parser.add_argument(
        "--basic-sc",
        required=False,
        default=False,
        action="store_true",
        help="use symspell (statistical spelling correction) instead of neural spell correction",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        default=False,
        help="load the smallest model for simple testing (ethzanalytics/distilgpt2-tiny-conversational)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=False,
        help="turn on verbose printing",
    )
    parser.add_argument(
        "-q",
        "--quiet",
        dest="loglevel",
        help="set loglevel to WARNING (reduce output)",
        action="store_const",
        const=logging.WARNING,
    )
    parser.add_argument(
        "-vv",
        "--very-verbose",
        dest="loglevel",
        help="set loglevel to DEBUG",
        action="store_const",
        const=logging.DEBUG,
    )
    return parser
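
# Example invocations (assuming this file is saved as app.py):
#   python app.py --test --verbose
#   python app.py -m ethzanalytics/ai-msgbot-gpt2-XL --basic-sc
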
if __name__ == "__main__":
    args = get_parser().parse_args()
    loglevel = args.loglevel or logging.INFO
    setup_logging(loglevel)
    logging.info("\n\n\nStarting app.py\n\n\n")
    logging.info(f"args: {args}")
    default_model = str(args.model)
    if args.test:
        logging.info("loading the smallest model for testing")
        default_model = "ethzanalytics/distilgpt2-tiny-conversational"
    model_loc = Path(default_model)  # in case the model is a local path
    basic_sc = args.basic_sc  # whether to use the baseline spellchecker
    gram_model = str(args.gram_model)
    device = 0 if torch.cuda.is_available() else -1  # pipeline convention: GPU 0 or CPU
    logging.info(f"CUDA avail is {torch.cuda.is_available()}")

    my_chatbot = (
        pipeline("text-generation", model=str(model_loc.resolve()), device=device)
        if model_loc.exists() and model_loc.is_dir()
        else pipeline("text-generation", model=default_model, device=device)
    )  # load from the local directory if it exists, otherwise from the hub by name
    logging.info(f"using model {my_chatbot.model}")

    if basic_sc:
        logging.info("using the baseline spellchecker")
        basic_spell = build_symspell_obj()
    else:
        logging.info("using neural spell checker")
        grammarbot = pipeline("text2text-generation", gram_model, device=device)

    logging.debug(f"using model stored here: \n {model_loc} \n")
    iface = gr.Interface(
        chat,
        inputs=[
            Textbox(
                default="Why is everyone here eating chocolate cake?",
                label="prompt_message",
                placeholder="Start a conversation with the bot",
                lines=2,
            ),
            Slider(
                minimum=0.0, maximum=1.0, step=0.05, default=0.4, label="temperature"
            ),
            Slider(minimum=0.0, maximum=1.0, step=0.01, default=0.95, label="top_p"),
            Slider(minimum=0, maximum=100, step=5, default=20, label="top_k"),
            Radio(
                choices=[True, False],
                default=False,
                label="constrained_generation",
            ),
        ],
        outputs="html",
        examples_per_page=8,
        examples=[
            ["Point Break or Bad Boys II?", 0.75, 0.95, 50, False],
            ["So... you're saying this wasn't an accident?", 0.6, 0.95, 40, False],
            ["Hi, my name is Reginald", 0.6, 0.95, 100, False],
            ["Happy birthday!", 0.9, 0.95, 50, False],
            ["I have a question, can you help me?", 0.6, 0.95, 50, False],
            ["Do you know a joke?", 0.8, 0.85, 50, False],
            ["Will you marry me?", 0.9, 0.95, 100, False],
            ["Are you single?", 0.95, 0.95, 100, False],
            ["Do you like people?", 0.7, 0.95, 25, False],
            ["You never took a shortcut before?", 0.7, 0.95, 100, False],
        ],
        title=f"GPT Chatbot Demo: {default_model} Model",
        description="A demo of a chatbot trained for conversation with humans. Size XL = 1.5B parameters.\n\n"
        "**Important Notes & About:**\n\n"
        "You can find a link to the model card **[here](https://huggingface.co/ethzanalytics/ai-msgbot-gpt2-XL-dialogue)**\n\n"
        "1. Responses can sometimes take up to 60 seconds; patience is a virtue.\n"
        "2. The model was trained on several different datasets; fact-check responses rather than taking them as true statements.\n"
        "3. Try adjusting the **[generation parameters](https://huggingface.co/blog/how-to-generate)** to get a better understanding of how they work!\n"
        "4. New - try using [constrained beam search](https://huggingface.co/blog/constrained-beam-search) decoding to generate more coherent responses. _(experimental, feedback welcome!)_\n",
        css="""
            .chatbox {display:flex;flex-direction:row}
            .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
            .user_msg {background-color:cornflowerblue;color:white;align-self:start}
            .resp_msg {background-color:lightgray;align-self:self-end}
            """,
        allow_flagging="never",
        theme="dark",
    )
    # launch the gradio interface and start the server
    iface.launch(
        enable_queue=True,
    )
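
# Note: with enable_queue=True, requests are served through a queue rather than
# parallel threads, which helps long generations (up to ~60s, per the notes in
# the description above) avoid request timeouts in the hosted UI.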