import os
import re
import warnings

import gradio as gr
import torch
from langchain import LLMChain, PromptTemplate
from langchain.llms.base import LLM
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
)

warnings.filterwarnings("ignore")


# initialize and load PEFT model and tokenizer
def init_model_and_tokenizer(PEFT_MODEL):
    """Load the quantized base model, attach the PEFT adapter, and load the tokenizer.

    Args:
        PEFT_MODEL: Identifier of the PEFT adapter; its config names the base
            model that is loaded underneath it.

    Returns:
        A ``(peft_model, peft_tokenizer)`` tuple.
    """
    config = PeftConfig.from_pretrained(PEFT_MODEL)

    # 4-bit NF4 weights with double quantization; compute runs in float16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    peft_base_model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        return_dict=True,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    peft_model = PeftModel.from_pretrained(peft_base_model, PEFT_MODEL)

    peft_tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    # The EOS token doubles as the padding token.
    peft_tokenizer.pad_token = peft_tokenizer.eos_token

    return peft_model, peft_tokenizer


# custom LLM chain to generate answer from PEFT model for each query
def init_llm_chain(peft_model, peft_tokenizer):
    """Build an ``LLMChain`` whose LLM generates text with the given PEFT model.

    Args:
        peft_model: Model used for ``generate``.
        peft_tokenizer: Tokenizer used to encode prompts and decode outputs.

    Returns:
        An ``LLMChain`` that takes a single ``query`` input variable.
    """
    # Built once; the settings do not depend on the prompt.
    # NOTE(review): temperature/top_p normally only take effect when
    # do_sample=True. Decoding mode is left exactly as the original had it -
    # confirm whether sampling was intended.
    generation_config = GenerationConfig(
        max_new_tokens=256,
        pad_token_id=peft_tokenizer.eos_token_id,
        eos_token_id=peft_tokenizer.eos_token_id,
        temperature=0.4,
        top_p=0.6,
        repetition_penalty=1.3,
        num_return_sequences=1,
    )

    class CustomLLM(LLM):
        """LangChain LLM wrapper around the enclosing PEFT model and tokenizer."""

        def _call(self, prompt: str, stop=None, run_manager=None) -> str:
            """Return only the text generated after ``prompt``.

            ``stop`` and ``run_manager`` are accepted for interface
            compatibility and are not used.
            """
            device = "cuda:0"
            peft_encoding = peft_tokenizer(prompt, return_tensors="pt").to(device)
            # attention_mask is a model input, so it is passed to generate()
            # directly; placed inside GenerationConfig it never reached the model.
            peft_outputs = peft_model.generate(
                input_ids=peft_encoding.input_ids,
                attention_mask=peft_encoding.attention_mask,
                generation_config=generation_config,
            )
            # The returned sequence starts with the prompt tokens; slice them
            # off so the echoed prompt is not part of the completion.
            prompt_length = peft_encoding.input_ids.shape[-1]
            return peft_tokenizer.decode(
                peft_outputs[0][prompt_length:], skip_special_tokens=True
            )

        @property
        def _llm_type(self) -> str:
            return "custom"

    llm = CustomLLM()

    # NOTE(review): the "Example Format" part has nothing in front of each ':'.
    # Speaker labels may have been lost at some point - confirm against the
    # prompt format the adapter was fine-tuned on. Text is kept byte-for-byte.
    template = (
        "Answer the following question truthfully. \n"
        "If you don't know the answer, respond "
        "'Sorry, I don't know the answer to this question.'. "
        "If the question is too complex, respond "
        "'Kindly, consult a psychiatrist for further queries.'. "
        "Example Format: : question here : answer here Begin! : {query} :"
    )

    prompt = PromptTemplate(template=template, input_variables=["query"])
    llm_chain = LLMChain(prompt=prompt, llm=llm)

    return llm_chain


def user(user_message, history):
    """Append the user's message to the history as a turn with no reply yet.

    Returns:
        ``("", new_history)``; the empty string is the first output value and
        the new history ends with ``[user_message, None]``.
    """
    return "", history + [[user_message, None]]


def bot(history):
    """Fill in the reply for the last turn of ``history`` and return it.

    When an earlier turn exists, its question and answer are prepended to the
    latest question. Uses the module-level ``llm_chain``.
    """
    latest_question = history[-1][0]
    if len(history) >= 2:
        previous_question, previous_answer = history[-2]
        # A previous turn can still hold None as its answer (see ``user``);
        # treat that as empty text instead of failing on str + None.
        query = (
            previous_question
            + "\n"
            + (previous_answer or "")
            + "\nHere, is the next QUESTION: "
            + latest_question
        )
    else:
        query = latest_question

    history[-1][1] = post_process_chat(llm_chain.run(query))
    return history


def post_process_chat(bot_message):
    """Trim a raw completion down to the text shown in the chat window.

    The previous implementation split on a pattern that can match the empty
    string and took the last piece, which is always ``''`` on Python 3.7+,
    so every reply came back empty. The answer is now isolated without any
    pattern that matches the empty string between characters.

    Args:
        bot_message: Completion text (without the prompt).

    Returns:
        The cleaned reply; may be empty if nothing usable remains.
    """
    # Drop everything from a repeated "Begin!" marker onwards.
    bot_message = bot_message.split("Begin!")[0]

    # Remove an optional leading colon and surrounding whitespace. The pattern
    # is anchored at the start, so it applies at most once.
    bot_message = re.sub(r"^\s*:?\s*", "", bot_message)

    # Keep text up to the last period so a cut-off trailing sentence is dropped.
    last_sentence = re.search(r"(.*\.)", bot_message, re.DOTALL)
    if last_sentence is not None:
        bot_message = last_sentence.group(1)

    # Remove a dangling final line made of a digit plus one character (e.g. "3.").
    bot_message = re.sub(r"\n\d.$", "", bot_message)

    # Cut the reply at the first sign-off phrase.
    bot_message = re.split(
        r"(Goodbye|Take care|Best Wishes)", bot_message, flags=re.IGNORECASE
    )[0].strip()

    return bot_message.replace("\n\n", "\n")


model = "heliosbrahma/falcon-7b-sharded-bf16-finetuned-mental-health-conversational"
peft_model, peft_tokenizer = init_model_and_tokenizer(PEFT_MODEL=model)

with gr.Blocks() as demo:
    gr.HTML("""Welcome to Mental Health Conversational AI""")
    gr.Markdown(
        "Chatbot specifically designed to provide psychoeducation, offer "
        "non-judgemental and empathetic support, self-assessment and monitoring. "
        "Get instant response for any mental health related queries. \n"
        "If the chatbot seems you need external support, then it will respond "
        "appropriately."
    )

    chatbot = gr.Chatbot()
    query = gr.Textbox(
        label="Type your query here, then press 'enter' and scroll up for response"
    )
    clear = gr.Button(value="Clear Chat History!")
    # NOTE(review): .style() appears to be deprecated in newer Gradio releases
    # in favour of passing size= to gr.Button - confirm against the pinned version.
    clear.style(size="sm")

    # ``bot`` looks this name up at module level when a query is submitted.
    llm_chain = init_llm_chain(peft_model, peft_tokenizer)

    query.submit(user, [query, chatbot], [query, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue().launch(inline=False)