File size: 9,008 Bytes
7149046 cda715f 7149046 30f266c cda715f 7149046 cda715f 7149046 cda715f 7149046 cda715f 7149046 cda715f 7149046 cda715f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, GenerationConfig
from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel, prepare_model_for_kbit_training
from trl import SFTTrainer
import warnings
# Dataset used for supervised fine-tuning.
data = load_dataset("heliosbrahma/mental_health_chatbot_dataset")

# Sharded falcon-7b model checkpoint.
model_name = "vilsonrodrigues/falcon-7b-instruct-sharded"

# Quantization settings applied when the base model is loaded.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # load model in 4-bit precision
    bnb_4bit_quant_type="nf4",  # quantize the pre-trained weights in 4-bit NF format
    bnb_4bit_use_double_quant=True,  # double quantization as mentioned in the QLoRA paper
    # The dtype is spelled `torch.bfloat16`; `torch.bf16` does not exist and
    # raises AttributeError as soon as this statement is evaluated.
    bnb_4bit_compute_dtype=torch.bfloat16,  # run computation in BF16
)
# Load the quantized base model.
model = AutoModelForCausalLM.from_pretrained(
    model_name,  # checkpoint to load; from_pretrained requires it as first argument
    quantization_config=bnb_config,  # use the bitsandbytes config
    device_map="auto",  # let HF Accelerate decide which GPU each layer is placed on
    trust_remote_code=True,  # falcon-7b ships custom modelling code
)

# Load the matching tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # reuse the EOS token as the padding token

# Prepare the quantized model for k-bit (QLoRA-style) training.
model = prepare_model_for_kbit_training(model)
# LoRA hyperparameters.
lora_alpha = 32  # scaling factor for the weight matrices
lora_dropout = 0.05  # dropout probability of the LoRA layers
lora_rank = 16  # dimension of the low-rank matrices

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_rank,
    bias="none",  # train only weight params, not biases
    task_type="CAUSAL_LM",  # the base model is loaded with AutoModelForCausalLM
    # Names of the falcon-7b modules that LoRA is applied to.
    # NOTE(review): the original list was truncated in this file; these are
    # the linear layers of a Falcon decoder block. Confirm against
    # model.named_modules() before training.
    target_modules=[
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ],
)

# Wrap the base model with the LoRA adapters.
peft_model = get_peft_model(model, peft_config)
# Training hyperparameters.
output_dir = "./falcon-7b-sharded-fp16-finetuned-mental-health-conversational"
per_device_train_batch_size = 16  # reduce batch size by 2x if out-of-memory error
gradient_accumulation_steps = 4  # increase by 2x if batch size is reduced
optim = "paged_adamw_32bit"  # activates paging for better memory management
save_strategy = "steps"  # checkpoint save strategy to adopt during training
save_steps = 10  # number of update steps between two checkpoint saves
logging_steps = 10  # number of update steps between two logs if logging_strategy="steps"
learning_rate = 2e-4  # learning rate for AdamW optimizer
max_grad_norm = 0.3  # maximum gradient norm (for gradient clipping)
max_steps = 70  # training will happen for 70 steps
warmup_ratio = 0.03  # fraction of total steps used for a linear warmup from 0 to learning_rate
lr_scheduler_type = "cosine"  # learning rate scheduler

# Hand every hyperparameter defined above to the trainer configuration;
# previously the call was empty, so none of them took effect.
training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_strategy=save_strategy,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    lr_scheduler_type=lr_scheduler_type,
)
# Supervised fine-tuning trainer over the LoRA-wrapped model.
# NOTE(review): the original argument list was truncated in this file. The
# "train" split and the "text" column are assumptions about the dataset
# loaded above; confirm both before running.
trainer = SFTTrainer(
    model=peft_model,
    train_dataset=data["train"],
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
)

# upcasting the layer norms in torch.bfloat16 for more stable training
for name, module in trainer.model.named_modules():
    if "norm" in name:
        # nn.Module.to() converts the module in place, so rebinding the loop
        # variable is unnecessary.
        module.to(torch.bfloat16)

# Disable the generation KV cache on the model config while training.
peft_model.config.use_cache = False

# NOTE(review): no trainer.train() call is visible in this file; add it here
# if this script is meant to launch training.
# import gradio as gr
# import torch
# import re, os, warnings
# from langchain import PromptTemplate, LLMChain
# from langchain.llms.base import LLM
# from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
# from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel
# warnings.filterwarnings("ignore")
# def init_model_and_tokenizer(PEFT_MODEL):
# config = PeftConfig.from_pretrained(PEFT_MODEL)
# bnb_config = BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_quant_type="nf4",
# bnb_4bit_use_double_quant=True,
# bnb_4bit_compute_dtype=torch.float16,
# )
# peft_base_model = AutoModelForCausalLM.from_pretrained(
# config.base_model_name_or_path,
# return_dict=True,
# quantization_config=bnb_config,
# device_map="auto",
# trust_remote_code=True,
# )
# peft_model = PeftModel.from_pretrained(peft_base_model, PEFT_MODEL)
# peft_tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
# peft_tokenizer.pad_token = peft_tokenizer.eos_token
# return peft_model, peft_tokenizer
# def init_llm_chain(peft_model, peft_tokenizer):
# class CustomLLM(LLM):
# def _call(self, prompt: str, stop=None, run_manager=None) -> str:
# device = "cuda:0"
# peft_encoding = peft_tokenizer(prompt, return_tensors="pt").to(device)
# peft_outputs = peft_model.generate(input_ids=peft_encoding.input_ids, generation_config=GenerationConfig(max_new_tokens=256, pad_token_id = peft_tokenizer.eos_token_id, \
# eos_token_id = peft_tokenizer.eos_token_id, attention_mask = peft_encoding.attention_mask, \
# temperature=0.4, top_p=0.6, repetition_penalty=1.3, num_return_sequences=1,))
# peft_text_output = peft_tokenizer.decode(peft_outputs[0], skip_special_tokens=True)
# return peft_text_output
# @property
# def _llm_type(self) -> str:
# return "custom"
# llm = CustomLLM()
# template = """Answer the following question truthfully.
# If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.
# If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'.
# Example Format:
# <HUMAN>: question here
# <ASSISTANT>: answer here
# Begin!
# <HUMAN>: {query}
# <ASSISTANT>:"""
# prompt = PromptTemplate(template=template, input_variables=["query"])
# llm_chain = LLMChain(prompt=prompt, llm=llm)
# return llm_chain
# def user(user_message, history):
# return "", history + [[user_message, None]]
# def bot(history):
# if len(history) >= 2:
# query = history[-2][0] + "\n" + history[-2][1] + "\nHere, is the next QUESTION: " + history[-1][0]
# else:
# query = history[-1][0]
# bot_message = llm_chain.run(query)
# bot_message = post_process_chat(bot_message)
# history[-1][1] = ""
# history[-1][1] += bot_message
# return history
# def post_process_chat(bot_message):
# try:
# bot_message = re.findall(r"<ASSISTANT>:.*?Begin!", bot_message, re.DOTALL)[1]
# except IndexError:
# pass
# bot_message = re.split(r'<ASSISTANT>\:?\s?', bot_message)[-1].split("Begin!")[0]
# bot_message = re.sub(r"^(.*?\.)(?=\n|$)", r"\1", bot_message, flags=re.DOTALL)
# try:
# bot_message = re.search(r"(.*\.)", bot_message, re.DOTALL).group(1)
# except AttributeError:
# pass
# bot_message = re.sub(r"\n\d.$", "", bot_message)
# bot_message = re.split(r"(Goodbye|Take care|Best Wishes)", bot_message, flags=re.IGNORECASE)[0].strip()
# bot_message = bot_message.replace("\n\n", "\n")
# return bot_message
# model = "heliosbrahma/falcon-7b-sharded-bf16-finetuned-mental-health-conversational"
# peft_model, peft_tokenizer = init_model_and_tokenizer(PEFT_MODEL = model)
# with gr.Blocks() as interface:
# gr.HTML("""<h1>Welcome to Mental Health Conversational AI</h1>""")
# gr.Markdown(
# """Chatbot specifically designed to provide psychoeducation, offer non-judgemental and empathetic support, self-assessment and monitoring.<br>
# Get instant response for any mental health related queries. If the chatbot seems you need external support, then it will respond appropriately.<br>"""
# )
# chatbot = gr.Chatbot()
# query = gr.Textbox(label="Type your query here, then press 'enter' and scroll up for response")
# clear = gr.Button(value="Clear Chat History!")
# llm_chain = init_llm_chain(peft_model, peft_tokenizer)
# query.submit(user, [query, chatbot], [query, chatbot], queue=False).then(bot, chatbot, chatbot)
# clear.click(lambda: None, None, chatbot, queue=False)
# interface.queue().launch() |