Spaces:
Runtime error
Runtime error
!pip install gradio transformers langchain -Uqqq | |
!pip install accelerate bitsandbytes einops git+https://github.com/huggingface/peft.git -Uqqq | |
import gradio as gr | |
import torch | |
import re, os, warnings | |
from langchain import PromptTemplate, LLMChain | |
from langchain.llms.base import LLM | |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig | |
from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel | |
warnings.filterwarnings("ignore") | |
# Initialize and load the PEFT model and its tokenizer.
def init_model_and_tokenizer(PEFT_MODEL):
    """Load the quantized base model, attach the PEFT adapter, and load the tokenizer.

    Args:
        PEFT_MODEL: Identifier handed to ``PeftConfig.from_pretrained`` and
            ``PeftModel.from_pretrained``.

    Returns:
        A ``(peft_model, peft_tokenizer)`` tuple.
    """
    # 4-bit NF4 quantization with double quantization; compute runs in float16.
    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    # The adapter config names the base checkpoint the adapter belongs to.
    adapter_config = PeftConfig.from_pretrained(PEFT_MODEL)
    base_checkpoint = adapter_config.base_model_name_or_path

    base_model = AutoModelForCausalLM.from_pretrained(
        base_checkpoint,
        return_dict=True,
        quantization_config=quantization,
        device_map="auto",
        trust_remote_code=True,
    )
    model_with_adapter = PeftModel.from_pretrained(base_model, PEFT_MODEL)

    tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
    # Reuse the end-of-sequence token as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    return model_with_adapter, tokenizer
# Custom LLM chain that generates an answer from the PEFT model for each query.
def init_llm_chain(peft_model, peft_tokenizer):
    """Build the LangChain chain that answers a single query with the PEFT model.

    Args:
        peft_model: Model object whose ``generate`` method produces the answer.
        peft_tokenizer: Tokenizer used to encode the prompt and decode the output.

    Returns:
        An ``LLMChain`` whose prompt has one input variable, ``query``.
    """

    class CustomLLM(LLM):
        """LangChain LLM wrapper that delegates text generation to ``peft_model``."""

        def _call(self, prompt: str, stop=None, run_manager=None, **kwargs) -> str:
            """Generate a completion for ``prompt`` and return the decoded text.

            ``stop``, ``run_manager`` and any extra keyword arguments that
            LangChain may forward are accepted and ignored.
            """
            # Follow the model's own placement (it is loaded with
            # device_map="auto") instead of assuming a fixed "cuda:0".
            device = peft_model.device
            peft_encoding = peft_tokenizer(prompt, return_tensors="pt").to(device)

            # NOTE(review): temperature/top_p only influence decoding when
            # sampling is enabled (do_sample=True); sampling is left off here
            # so existing output behavior is unchanged -- confirm intent.
            generation_config = GenerationConfig(
                max_new_tokens=256,
                pad_token_id=peft_tokenizer.eos_token_id,
                eos_token_id=peft_tokenizer.eos_token_id,
                temperature=0.4,
                top_p=0.6,
                repetition_penalty=1.3,
                num_return_sequences=1,
            )

            # The attention mask is a model input, not a generation setting,
            # so it is passed to generate() directly rather than being stored
            # on the GenerationConfig.
            peft_outputs = peft_model.generate(
                input_ids=peft_encoding.input_ids,
                attention_mask=peft_encoding.attention_mask,
                generation_config=generation_config,
            )
            return peft_tokenizer.decode(peft_outputs[0], skip_special_tokens=True)

        @property
        def _llm_type(self) -> str:
            """Identifier LangChain uses for this LLM implementation."""
            return "custom"

    llm = CustomLLM()

    # NOTE(review): the turn markers below are bare ":" lines; any speaker
    # labels that should precede them are not present here -- confirm against
    # the format the model was fine-tuned on.
    template = """Answer the following question truthfully.
If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.
If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'.
Example Format:
: question here
: answer here
Begin!
: {query}
:"""
    prompt = PromptTemplate(template=template, input_variables=["query"])
    llm_chain = LLMChain(prompt=prompt, llm=llm)
    return llm_chain
def user(user_message, history):
    """Append the submitted message to the chat as a new, unanswered turn.

    Args:
        user_message: Text the user submitted.
        history: Existing chat turns, each a ``[user_message, bot_message]`` pair.

    Returns:
        A tuple of an empty string (the new textbox content) and a new history
        list ending with ``[user_message, None]``; ``history`` is not mutated.
    """
    pending_turn = [user_message, None]
    extended_history = history + [pending_turn]
    return "", extended_history
def bot(history):
    """Answer the latest user message and store the reply in ``history``.

    The previous turn is prepended as context when it has both a question and
    an answer. A previous turn whose answer is still ``None`` (for example
    because generating it failed) is skipped, so it cannot break the string
    concatenation for every later message.

    Args:
        history: Chat turns as ``[user_message, bot_message]`` pairs; the last
            pair holds the question to answer.

    Returns:
        The same ``history`` list, with the last pair's answer filled in.
    """
    question = history[-1][0]
    previous_turn = history[-2] if len(history) >= 2 else None

    if previous_turn is not None and None not in (previous_turn[0], previous_turn[1]):
        query = (
            previous_turn[0]
            + "\n"
            + previous_turn[1]
            + "\nHere, is the next QUESTION: "
            + question
        )
    else:
        query = question

    bot_message = post_process_chat(llm_chain.run(query))
    history[-1][1] = bot_message
    return history
def post_process_chat(bot_message):
    """Reduce raw chain output to the final answer text.

    Processing steps, in order:
      1. If the text contains a second ``":" ... "Begin!"`` span, narrow to it.
      2. Keep only what follows the last line-leading ``":"`` turn marker and
         precedes any ``"Begin!"``.
      3. Cut a trailing incomplete sentence (everything after the last period).
      4. Remove a dangling ``"\\n<digit><char>"`` list stub at the very end.
      5. Cut the text at the first sign-off phrase and strip whitespace.
      6. Collapse blank lines.

    Args:
        bot_message: Raw text returned by the chain.

    Returns:
        The cleaned answer; an empty string if nothing is left.
    """
    # Step 1: explicit length check instead of indexing and catching IndexError.
    spans = re.findall(r":.*?Begin!", bot_message, re.DOTALL)
    if len(spans) > 1:
        bot_message = spans[1]

    # Step 2: the marker pattern must not be able to match the empty string.
    # A fully optional pattern makes re.split cut between every character, so
    # the last piece is always "" and the whole answer is discarded. Anchoring
    # on a mandatory ":" at the start of a line matches the turn markers used
    # in the prompt and leaves colons inside the answer untouched.
    # NOTE(review): if the deployed prompt uses labelled markers, update this
    # pattern to match them -- confirm against the prompt template.
    bot_message = re.split(r"^[ \t]*:\s?", bot_message, flags=re.MULTILINE)[-1]
    bot_message = bot_message.split("Begin!")[0]

    # Step 3: keep everything up to and including the last period, if any.
    last_period = bot_message.rfind(".")
    if last_period != -1:
        bot_message = bot_message[: last_period + 1]

    # Step 4.
    bot_message = re.sub(r"\n\d.$", "", bot_message)

    # Step 5.
    bot_message = re.split(
        r"(Goodbye|Take care|Best Wishes)", bot_message, flags=re.IGNORECASE
    )[0].strip()

    # Step 6.
    return bot_message.replace("\n\n", "\n")
# Identifier of the fine-tuned PEFT model handed to the loader.
model = "heliosbrahma/falcon-7b-sharded-bf16-finetuned-mental-health-conversational"
peft_model, peft_tokenizer = init_model_and_tokenizer(PEFT_MODEL=model)

with gr.Blocks() as demo:
    gr.HTML("""Welcome to Mental Health Conversational AI""")
    gr.Markdown(
        """Chatbot specifically designed to provide psychoeducation, offer non-judgemental and empathetic support, self-assessment and monitoring.
Get instant response for any mental health related queries. If the chatbot seems you need external support, then it will respond appropriately."""
    )
    chatbot = gr.Chatbot()
    query = gr.Textbox(label="Type your query here, then press 'enter' and scroll up for response")
    # The size is passed to the constructor: the component `.style()` method
    # is not available in current Gradio releases, and this app installs
    # Gradio with `-U`, so calling it would fail at startup.
    clear = gr.Button(value="Clear Chat History!", size="sm")

    # Module-level chain used by `bot` when answering.
    llm_chain = init_llm_chain(peft_model, peft_tokenizer)

    # On submit: first record the user's turn and clear the textbox (not
    # queued), then generate the reply for the updated chat.
    query.submit(user, [query, chatbot], [query, chatbot], queue=False).then(bot, chatbot, chatbot)
    # Clearing resets the chatbot component to an empty value.
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue().launch(inline=False)