"""gradio_chatbot_app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/#fileId=https%3A//huggingface.co/spaces/ZamiSanj/therapx/blob/main/gradio_chatbot_app.ipynb
"""

import gradio as gr
import torch
import re, os, warnings

from langchain import PromptTemplate, LLMChain
from langchain.llms.base import LLM
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel

warnings.filterwarnings("ignore")

def init_model_and_tokenizer(PEFT_MODEL):
    # Read the adapter config to find the base model the LoRA weights were trained on.
    config = PeftConfig.from_pretrained(PEFT_MODEL)

    # Quantize the base model to 4-bit NF4 (with double quantization) so it fits on a single GPU.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    peft_base_model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        return_dict=True,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

    # Attach the fine-tuned LoRA adapter on top of the quantized base model.
    peft_model = PeftModel.from_pretrained(peft_base_model, PEFT_MODEL)

    peft_tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    peft_tokenizer.pad_token = peft_tokenizer.eos_token

    return peft_model, peft_tokenizer

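# Illustrative only (not executed): a quick sanity check of the loaded model and tokenizer.
# The prompt text below is an assumed example, not part of the original notebook.
#
#   model, tokenizer = init_model_and_tokenizer("heliosbrahma/falcon-7b-sharded-bf16-finetuned-mental-health-conversational")
#   enc = tokenizer("<HUMAN>: How do I cope with anxiety?\n<ASSISTANT>:", return_tensors="pt").to("cuda:0")
#   print(tokenizer.decode(model.generate(**enc, max_new_tokens=64)[0], skip_special_tokens=True))
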
def init_llm_chain(peft_model, peft_tokenizer):
    # Minimal LangChain wrapper so the local PEFT model can be driven by an LLMChain.
    class CustomLLM(LLM):
        def _call(self, prompt: str, stop=None, run_manager=None) -> str:
            device = "cuda:0"
            peft_encoding = peft_tokenizer(prompt, return_tensors="pt").to(device)
            # Pass the attention mask to generate() directly (it is not a GenerationConfig field),
            # and enable sampling so temperature/top_p actually take effect.
            peft_outputs = peft_model.generate(
                input_ids=peft_encoding.input_ids,
                attention_mask=peft_encoding.attention_mask,
                generation_config=GenerationConfig(
                    max_new_tokens=256,
                    pad_token_id=peft_tokenizer.eos_token_id,
                    eos_token_id=peft_tokenizer.eos_token_id,
                    do_sample=True,
                    temperature=0.4,
                    top_p=0.6,
                    repetition_penalty=1.3,
                    num_return_sequences=1,
                ),
            )
            peft_text_output = peft_tokenizer.decode(peft_outputs[0], skip_special_tokens=True)
            return peft_text_output

        @property
        def _llm_type(self) -> str:
            return "custom"

    llm = CustomLLM()

    template = """Answer the following question truthfully.
If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.
If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'.

Example Format:
<HUMAN>: question here
<ASSISTANT>: answer here

Begin!

<HUMAN>: {query}
<ASSISTANT>:"""

    prompt = PromptTemplate(template=template, input_variables=["query"])
    llm_chain = LLMChain(prompt=prompt, llm=llm)

    return llm_chain

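# Illustrative only (not executed): how the chain is used. LLMChain.run fills {query} into the
# prompt template and returns the raw completion; the query string below is an assumed example.
#
#   chain = init_llm_chain(peft_model, peft_tokenizer)
#   raw = chain.run("How can I manage exam stress?")
#   print(post_process_chat(raw))
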
def user(user_message, history):
    # Append the new user turn (bot reply still pending) and clear the textbox.
    return "", history + [[user_message, None]]


def bot(history):
    # Prepend the previous turn so the model sees one round of conversation context.
    if len(history) >= 2:
        query = history[-2][0] + "\n" + history[-2][1] + "\nHere, is the next QUESTION: " + history[-1][0]
    else:
        query = history[-1][0]

    bot_message = llm_chain.run(query)
    bot_message = post_process_chat(bot_message)

    history[-1][1] = bot_message
    return history

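# Illustrative only (not executed): gr.Chatbot history is a list of [user, bot] pairs, so with
# history = [["Hi", "Hello!"], ["What is CBT?", None]] the bot() above would send the model
# "Hi\nHello!\nHere, is the next QUESTION: What is CBT?" to carry one turn of context.
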
def post_process_chat(bot_message):
    # The decoded output echoes the prompt template, so keep the segment after the second
    # occurrence of the "<ASSISTANT>: ... Begin!" pattern when it is present.
    try:
        bot_message = re.findall(r"<ASSISTANT>:.*?Begin!", bot_message, re.DOTALL)[1]
    except IndexError:
        pass

    # Keep only the text after the last "<ASSISTANT>:" marker and drop any echoed "Begin!".
    bot_message = re.split(r'<ASSISTANT>\:?\s?', bot_message)[-1].split("Begin!")[0]

    # Trim the reply to its last complete sentence.
    bot_message = re.sub(r"^(.*?\.)(?=\n|$)", r"\1", bot_message, flags=re.DOTALL)
    try:
        bot_message = re.search(r"(.*\.)", bot_message, re.DOTALL).group(1)
    except AttributeError:
        pass

    # Drop a dangling numbered-list stub at the end, cut at sign-off phrases, and tidy blank lines.
    bot_message = re.sub(r"\n\d.$", "", bot_message)
    bot_message = re.split(r"(Goodbye|Take care|Best Wishes)", bot_message, flags=re.IGNORECASE)[0].strip()
    bot_message = bot_message.replace("\n\n", "\n")

    return bot_message

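# Illustrative only (not executed): given a raw completion that echoes the prompt and ends with
# "<ASSISTANT>: CBT is a talking therapy. Take care!", post_process_chat() would return just
# "CBT is a talking therapy." (example text assumed).
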
model = "heliosbrahma/falcon-7b-sharded-bf16-finetuned-mental-health-conversational" |
|
peft_model, peft_tokenizer = init_model_and_tokenizer(PEFT_MODEL = model) |
|
|
|
with gr.Blocks() as demo:
    gr.HTML("""<h1>TherapX</h1>""")
    gr.Markdown(
        """A chatbot designed to provide psychoeducation, offer non-judgemental and empathetic support, and help with self-assessment and monitoring.<br>
Get an instant response to any mental-health-related query. If the chatbot finds that you need external support, it will respond accordingly.<br>"""
    )

    chatbot = gr.Chatbot()
    query = gr.Textbox(label="Type your query here, then press 'enter' and scroll up for response")
    clear = gr.Button(value="Clear Chat History!")
    clear.style(size="sm")

    llm_chain = init_llm_chain(peft_model, peft_tokenizer)

    # On submit: record the user turn, then run the model and fill in the bot reply.
    query.submit(user, [query, chatbot], [query, chatbot], queue=False).then(bot, chatbot, chatbot)
    clear.click(lambda: None, None, chatbot, queue=False)

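# Illustrative only (not executed): launch() also accepts share=True for a temporary public URL,
# e.g. `demo.queue().launch(share=True)`, which can be handy when testing from Colab.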
demo.queue().launch()