prathik31 committed on
Commit aeccd56 • 1 Parent(s): f9f1b30

Create app.py

Files changed (1)
  1. app.py +133 -0
app.py ADDED
@@ -0,0 +1,133 @@
+ # !pip install gradio transformers langchain -Uqqq
+ # !pip install accelerate bitsandbytes einops git+https://github.com/huggingface/peft.git -Uqqq
+ # Note: "!pip" is notebook shell syntax and is not valid in a plain Python script;
+ # install these dependencies beforehand (e.g. via the Space's requirements.txt).
+
+
+ import gradio as gr
+ import torch
+ import re, os, warnings
+ from langchain import PromptTemplate, LLMChain
+ from langchain.llms.base import LLM
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
+ from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel
+ warnings.filterwarnings("ignore")
+
+ # initialize and load PEFT model and tokenizer
+ def init_model_and_tokenizer(PEFT_MODEL):
+     config = PeftConfig.from_pretrained(PEFT_MODEL)
+     # 4-bit NF4 quantization so the 7B base model fits in a single GPU's memory
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_compute_dtype=torch.float16,
+     )
+
+     peft_base_model = AutoModelForCausalLM.from_pretrained(
+         config.base_model_name_or_path,
+         return_dict=True,
+         quantization_config=bnb_config,
+         device_map="auto",
+         trust_remote_code=True,
+     )
+
+     # attach the fine-tuned LoRA adapter weights on top of the quantized base model
+     peft_model = PeftModel.from_pretrained(peft_base_model, PEFT_MODEL)
+
+     peft_tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
+     peft_tokenizer.pad_token = peft_tokenizer.eos_token
+
+     return peft_model, peft_tokenizer
+
+
+ # custom LLM chain to generate answer from PEFT model for each query
+ def init_llm_chain(peft_model, peft_tokenizer):
+     class CustomLLM(LLM):
+         def _call(self, prompt: str, stop=None, run_manager=None) -> str:
+             device = "cuda:0"
+             peft_encoding = peft_tokenizer(prompt, return_tensors="pt").to(device)
+             # attention_mask is passed to generate() itself; the sampling settings in
+             # GenerationConfig only take effect with do_sample=True
+             peft_outputs = peft_model.generate(
+                 input_ids=peft_encoding.input_ids,
+                 attention_mask=peft_encoding.attention_mask,
+                 generation_config=GenerationConfig(
+                     max_new_tokens=256,
+                     pad_token_id=peft_tokenizer.eos_token_id,
+                     eos_token_id=peft_tokenizer.eos_token_id,
+                     do_sample=True,
+                     temperature=0.4,
+                     top_p=0.6,
+                     repetition_penalty=1.3,
+                     num_return_sequences=1,
+                 ),
+             )
+             peft_text_output = peft_tokenizer.decode(peft_outputs[0], skip_special_tokens=True)
+             return peft_text_output
+
+         @property
+         def _llm_type(self) -> str:
+             return "custom"
+
+     llm = CustomLLM()
+
+     # prompt template with <HUMAN>/<ASSISTANT> role tags
+     template = """Answer the following question truthfully.
+ If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.
+ If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'.
+
+ Example Format:
+ <HUMAN>: question here
+ <ASSISTANT>: answer here
+
+ Begin!
+
+ <HUMAN>: {query}
+ <ASSISTANT>:"""
+
+     prompt = PromptTemplate(template=template, input_variables=["query"])
+     llm_chain = LLMChain(prompt=prompt, llm=llm)
+
+     return llm_chain
+
+ def user(user_message, history):
+     return "", history + [[user_message, None]]
+
+ def bot(history):
+     # include the previous question/answer pair as context for the new question
+     if len(history) >= 2:
+         query = history[-2][0] + "\n" + history[-2][1] + "\nHere, is the next QUESTION: " + history[-1][0]
+     else:
+         query = history[-1][0]
+
+     bot_message = llm_chain.run(query)
+     bot_message = post_process_chat(bot_message)
+
+     history[-1][1] = ""
+     history[-1][1] += bot_message
+     return history
+
+ def post_process_chat(bot_message):
+     # strip the echoed prompt and keep only the text generated after the final <ASSISTANT> tag
+     try:
+         bot_message = re.findall(r"<ASSISTANT>:.*?Begin!", bot_message, re.DOTALL)[1]
+     except IndexError:
+         pass
+
+     bot_message = re.split(r'<ASSISTANT>\:?\s?', bot_message)[-1].split("Begin!")[0]
+
+     # trim to complete sentences and drop trailing list fragments and sign-offs
+     bot_message = re.sub(r"^(.*?\.)(?=\n|$)", r"\1", bot_message, flags=re.DOTALL)
+     try:
+         bot_message = re.search(r"(.*\.)", bot_message, re.DOTALL).group(1)
+     except AttributeError:
+         pass
+
+     bot_message = re.sub(r"\n\d.$", "", bot_message)
+     bot_message = re.split(r"(Goodbye|Take care|Best Wishes)", bot_message, flags=re.IGNORECASE)[0].strip()
+     bot_message = bot_message.replace("\n\n", "\n")
+
+     return bot_message
+
+ model = "heliosbrahma/falcon-7b-sharded-bf16-finetuned-mental-health-conversational"
+ peft_model, peft_tokenizer = init_model_and_tokenizer(PEFT_MODEL=model)
+
+
+ with gr.Blocks() as demo:
+     gr.HTML("""<h1>Welcome to Mental Health Conversational AI</h1>""")
+     gr.Markdown(
+         """A chatbot specifically designed to provide psychoeducation and offer non-judgemental, empathetic support, self-assessment and monitoring.
+         Get instant responses to any mental-health-related queries. If the chatbot determines that you need external support, it will respond appropriately."""
+     )
+
+     chatbot = gr.Chatbot()
+     query = gr.Textbox(label="Type your query here, then press 'enter' and scroll up for response")
+     clear = gr.Button(value="Clear Chat History!")
+     clear.style(size="sm")
+
+     llm_chain = init_llm_chain(peft_model, peft_tokenizer)
+
+     query.submit(user, [query, chatbot], [query, chatbot], queue=False).then(bot, chatbot, chatbot)
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+ demo.queue().launch(inline=False)
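
For a quick sanity check outside the Gradio UI, the chain can be exercised directly. The snippet below is a sketch rather than part of the committed file: it assumes the functions defined in app.py above (init_model_and_tokenizer, init_llm_chain, post_process_chat) are in scope, a CUDA GPU is available, and the listed dependencies are installed; the sample question is purely illustrative.

# Sketch only (not part of this commit): run one query through the chain directly,
# reusing the helpers defined in app.py above.
PEFT_MODEL = "heliosbrahma/falcon-7b-sharded-bf16-finetuned-mental-health-conversational"
peft_model, peft_tokenizer = init_model_and_tokenizer(PEFT_MODEL)
llm_chain = init_llm_chain(peft_model, peft_tokenizer)

raw_answer = llm_chain.run("How can I manage stress before an exam?")  # illustrative query
print(post_process_chat(raw_answer))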