import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, GenerationConfig
from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel, prepare_model_for_kbit_training
from trl import SFTTrainer
import warnings
warnings.filterwarnings("ignore")

data = load_dataset("heliosbrahma/mental_health_chatbot_dataset")
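# Quick look at the data (optional sketch): the SFTTrainer below reads the "text" field of each
# record, which appears to hold the full <HUMAN>/<ASSISTANT>-formatted conversation string.
# print(data)                       # split names and sizes
# print(data["train"][0]["text"])   # one formatted training example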
model_name = "vilsonrodrigues/falcon-7b-instruct-sharded" # sharded falcon-7b model

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # load model weights in 4-bit precision
    bnb_4bit_quant_type="nf4",              # use 4-bit NormalFloat (NF4) quantization
    bnb_4bit_use_double_quant=True,         # double quantization, as described in the QLoRA paper
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute in bfloat16 during the forward/backward pass
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config, # Use bitsandbytes config
    device_map="auto",  # Specifying device_map="auto" so that HF Accelerate will determine which GPU to put each layer of the model on
    trust_remote_code=True, # Set trust_remote_code=True to use falcon-7b model with custom code
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token # Setting pad_token same as eos_token
model = prepare_model_for_kbit_training(model) # freeze base weights, upcast norm/output layers, and enable gradient checkpointing for stable k-bit training

lora_alpha = 32 # scaling factor for the weight matrices
lora_dropout = 0.05 # dropout probability of the LoRA layers
lora_rank = 16 # dimension of the low-rank matrices
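# Note: PEFT scales the LoRA update by lora_alpha / r, so this configuration applies a factor of 32 / 16 = 2.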

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_rank,
    bias="none",  # setting to 'none' for only training weight params instead of biases
    task_type="CAUSAL_LM",
    target_modules=[         # Setting names of modules in falcon-7b model that we want to apply LoRA to
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]
)

peft_model = get_peft_model(model, peft_config)
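# Sanity check: report how many parameters LoRA actually trains versus the frozen base model.
peft_model.print_trainable_parameters()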

output_dir = "./falcon-7b-sharded-fp16-finetuned-mental-health-conversational"
per_device_train_batch_size = 16 # reduce batch size by 2x if out-of-memory error
gradient_accumulation_steps = 4  # increase gradient accumulation steps by 2x if batch size is reduced
optim = "paged_adamw_32bit" # activates the paging for better memory management
save_strategy="steps" # checkpoint save strategy to adopt during training
save_steps = 10 # number of update steps between two checkpoint saves
logging_steps = 10  # number of update steps between two logs if logging_strategy="steps"
learning_rate = 2e-4  # learning rate for AdamW optimizer
max_grad_norm = 0.3 # maximum gradient norm (for gradient clipping)
max_steps = 70        # training will happen for 70 steps
warmup_ratio = 0.03 # fraction of total training steps used for a linear warmup from 0 to learning_rate
lr_scheduler_type = "cosine"  # learning rate scheduler

training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_strategy=save_strategy,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    bf16=True,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=True,
    lr_scheduler_type=lr_scheduler_type,
    push_to_hub=True,
)
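# Note: the effective batch size per optimizer step is
# per_device_train_batch_size * gradient_accumulation_steps = 16 * 4 = 64 examples.
# push_to_hub=True assumes you are already authenticated to the Hugging Face Hub (e.g. via `huggingface-cli login`).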
trainer = SFTTrainer(
    model=peft_model,
    train_dataset=data['train'],
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=1024,  # maximum sequence length used by the trainer
    tokenizer=tokenizer,
    args=training_arguments,
)

# upcast the layer norms to torch.bfloat16 for more stable training
for name, module in trainer.model.named_modules():
    if "norm" in name:
        module = module.to(torch.bfloat16)

peft_model.config.use_cache = False
trainer.train()
trainer.push_to_hub("therapx") # the positional argument is used as the commit message
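# Optionally keep a local copy of the trained LoRA adapter and tokenizer as well
# (a sketch; the Hub push above already uploads them):
# trainer.model.save_pretrained(output_dir)
# tokenizer.save_pretrained(output_dir)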

# import gradio as gr
# import torch
# import re, os, warnings
# from langchain import PromptTemplate, LLMChain
# from langchain.llms.base import LLM
# from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
# from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel
# warnings.filterwarnings("ignore")

# def init_model_and_tokenizer(PEFT_MODEL):
#   config = PeftConfig.from_pretrained(PEFT_MODEL)
#   bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=torch.float16,
#   )

#   peft_base_model = AutoModelForCausalLM.from_pretrained(
#     config.base_model_name_or_path,
#     return_dict=True,
#     quantization_config=bnb_config,
#     device_map="auto",
#     trust_remote_code=True,
#   )

#   peft_model = PeftModel.from_pretrained(peft_base_model, PEFT_MODEL)

#   peft_tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
#   peft_tokenizer.pad_token = peft_tokenizer.eos_token

#   return peft_model, peft_tokenizer

# def init_llm_chain(peft_model, peft_tokenizer):
#     class CustomLLM(LLM):
#         def _call(self, prompt: str, stop=None, run_manager=None) -> str:
#             device = "cuda:0"
#             peft_encoding = peft_tokenizer(prompt, return_tensors="pt").to(device)
#             peft_outputs = peft_model.generate(
#                 input_ids=peft_encoding.input_ids, attention_mask=peft_encoding.attention_mask,
#                 generation_config=GenerationConfig(max_new_tokens=256, pad_token_id=peft_tokenizer.eos_token_id,
#                                                    eos_token_id=peft_tokenizer.eos_token_id, do_sample=True,
#                                                    temperature=0.4, top_p=0.6, repetition_penalty=1.3,
#                                                    num_return_sequences=1))
#             peft_text_output = peft_tokenizer.decode(peft_outputs[0], skip_special_tokens=True)
#             return peft_text_output

#         @property
#         def _llm_type(self) -> str:
#             return "custom"

#     llm = CustomLLM()

#     template = """Answer the following question truthfully.
#     If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.
#     If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'.

#     Example Format:
#     <HUMAN>: question here
#     <ASSISTANT>: answer here

#     Begin!

#     <HUMAN>: {query}
#     <ASSISTANT>:"""

#     prompt = PromptTemplate(template=template, input_variables=["query"])
#     llm_chain = LLMChain(prompt=prompt, llm=llm)

#     return llm_chain

# def user(user_message, history):
#   return "", history + [[user_message, None]]

# def bot(history):
#   if len(history) >= 2:
#     query = history[-2][0] + "\n" + history[-2][1] + "\nHere, is the next QUESTION: " + history[-1][0]
#   else:
#     query = history[-1][0]

#   bot_message = llm_chain.run(query)
#   bot_message = post_process_chat(bot_message)

#   history[-1][1] = ""
#   history[-1][1] += bot_message
#   return history

# def post_process_chat(bot_message):
#   try:
#     bot_message = re.findall(r"<ASSISTANT>:.*?Begin!", bot_message, re.DOTALL)[1]
#   except IndexError:
#     pass

#   bot_message = re.split(r'<ASSISTANT>\:?\s?', bot_message)[-1].split("Begin!")[0]

#   bot_message = re.sub(r"^(.*?\.)(?=\n|$)", r"\1", bot_message, flags=re.DOTALL)
#   try:
#     bot_message = re.search(r"(.*\.)", bot_message, re.DOTALL).group(1)
#   except AttributeError:
#     pass

#   bot_message = re.sub(r"\n\d.$", "", bot_message)
#   bot_message = re.split(r"(Goodbye|Take care|Best Wishes)", bot_message, flags=re.IGNORECASE)[0].strip()
#   bot_message = bot_message.replace("\n\n", "\n")

#   return bot_message

# model = "heliosbrahma/falcon-7b-sharded-bf16-finetuned-mental-health-conversational"
# peft_model, peft_tokenizer = init_model_and_tokenizer(PEFT_MODEL = model)

# with gr.Blocks() as interface:
#     gr.HTML("""<h1>Welcome to Mental Health Conversational AI</h1>""")
#     gr.Markdown(
#         """A chatbot specifically designed to provide psychoeducation and offer non-judgemental, empathetic support, self-assessment, and monitoring.<br>
#         Get an instant response to any mental-health-related query. If the chatbot senses that you need external support, it will respond appropriately.<br>"""
#     )

#     chatbot = gr.Chatbot()
#     query = gr.Textbox(label="Type your query here, then press 'enter' and scroll up for response")
#     clear = gr.Button(value="Clear Chat History!")
#     clear.style(size="sm")

#     llm_chain = init_llm_chain(peft_model, peft_tokenizer)

#     query.submit(user, [query, chatbot], [query, chatbot], queue=False).then(bot, chatbot, chatbot)
#     clear.click(lambda: None, None, chatbot, queue=False)

# interface.queue().launch()