{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QyqHXpfzNTi5", "outputId": "9612f9b9-1b51-4469-e50c-d0a0ffdc8e9c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n" ] } ], "source": [ "!pip install gradio==3.41.0 transformers==4.32.0 langchain==0.0.273 -Uqqq\n", "!pip install accelerate==0.12.0 bitsandbytes==0.41.1 einops==0.7.0 peft==0.5.0 -Uqqq" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7uflzy2_OAjL" }, "outputs": [], "source": [ "import gradio as gr\n", "import torch\n", "import re, os, warnings\n", "from langchain import PromptTemplate, LLMChain\n", "from langchain.llms.base import LLM\n", "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig\n", "from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel\n", "warnings.filterwarnings(\"ignore\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "OYx4Jyh2OV17" }, "outputs": [], "source": [
 "# initialize and load PEFT model and tokenizer\n",
 "def init_model_and_tokenizer(PEFT_MODEL):\n",
 "    \"\"\"Load the 4-bit quantized base model, apply the PEFT adapter, and load the tokenizer.\n",
 "\n",
 "    Args:\n",
 "        PEFT_MODEL: Hub id or local path of the PEFT adapter; its PeftConfig names\n",
 "            the base model to load.\n",
 "\n",
 "    Returns:\n",
 "        (peft_model, peft_tokenizer): the base model wrapped with the adapter, and the\n",
 "        base model's tokenizer with its pad token set to the EOS token.\n",
 "    \"\"\"\n",
 "    config = PeftConfig.from_pretrained(PEFT_MODEL)\n",
 "    # NF4 4-bit weights with double quantization; compute dtype is float16.\n",
 "    bnb_config = BitsAndBytesConfig(\n",
 "        load_in_4bit=True,\n",
 "        bnb_4bit_quant_type=\"nf4\",\n",
 "        bnb_4bit_use_double_quant=True,\n",
 "        bnb_4bit_compute_dtype=torch.float16,\n",
 "    )\n",
 "\n",
 "    # trust_remote_code=True runs modelling code downloaded from the model repository;\n",
 "    # only point this at repositories you trust.\n",
 "    peft_base_model = AutoModelForCausalLM.from_pretrained(\n",
 "        config.base_model_name_or_path,\n",
 "        return_dict=True,\n",
 "        quantization_config=bnb_config,\n",
 "        device_map=\"auto\",\n",
 "        trust_remote_code=True,\n",
 "    )\n",
 "\n",
 "    peft_model = PeftModel.from_pretrained(peft_base_model, PEFT_MODEL)\n",
 "\n",
 "    peft_tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
 "    # Reuse EOS as the padding token (presumably the base tokenizer defines none).\n",
 "    peft_tokenizer.pad_token = peft_tokenizer.eos_token\n",
 "\n",
 "    return peft_model, peft_tokenizer"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Azg2zEW2OkP7" }, "outputs": [], "source": [
 "# custom LLM chain to generate answer from PEFT model for each query\n",
 "def init_llm_chain(peft_model, peft_tokenizer):\n",
 "    \"\"\"Build an LLMChain that answers a single query with the PEFT model.\n",
 "\n",
 "    Args:\n",
 "        peft_model: model returned by init_model_and_tokenizer.\n",
 "        peft_tokenizer: tokenizer returned alongside it.\n",
 "\n",
 "    Returns:\n",
 "        An LLMChain; llm_chain.run(query) substitutes query into the prompt template,\n",
 "        generates with peft_model and returns the decoded output sequence.\n",
 "    \"\"\"\n",
 "    class CustomLLM(LLM):\n",
 "        \"\"\"LangChain LLM wrapper around peft_model / peft_tokenizer from the enclosing scope.\"\"\"\n",
 "\n",
 "        def _call(self, prompt: str, stop=None, run_manager=None) -> str:\n",
 "            # `stop` and `run_manager` belong to LangChain's LLM interface; they are not used here.\n",
 "            # Use the model's own device instead of hard-coding \"cuda:0\": device_map=\"auto\"\n",
 "            # decides where the model is placed.\n",
 "            peft_encoding = peft_tokenizer(prompt, return_tensors=\"pt\").to(peft_model.device)\n",
 "            generation_config = GenerationConfig(\n",
 "                max_new_tokens=256,\n",
 "                pad_token_id=peft_tokenizer.eos_token_id,\n",
 "                eos_token_id=peft_tokenizer.eos_token_id,\n",
 "                # NOTE(review): do_sample is not enabled, so decoding is greedy and\n",
 "                # temperature/top_p have no effect; set do_sample=True if sampling was intended.\n",
 "                temperature=0.4,\n",
 "                top_p=0.6,\n",
 "                repetition_penalty=1.3,\n",
 "                num_return_sequences=1,\n",
 "            )\n",
 "            # attention_mask is a model input, not a generation setting: GenerationConfig only\n",
 "            # stores it as an unused attribute, so it has to be passed to generate() itself.\n",
 "            peft_outputs = peft_model.generate(\n",
 "                input_ids=peft_encoding.input_ids,\n",
 "                attention_mask=peft_encoding.attention_mask,\n",
 "                generation_config=generation_config,\n",
 "            )\n",
 "            # Decodes the full output sequence (the prompt followed by the continuation).\n",
 "            peft_text_output = peft_tokenizer.decode(peft_outputs[0], skip_special_tokens=True)\n",
 "            return peft_text_output\n",
 "\n",
 "        @property\n",
 "        def _llm_type(self) -> str:\n",
 "            return \"custom\"\n",
 "\n",
 "    llm = CustomLLM()\n",
 "\n",
 "    # NOTE(review): the example and turn lines start with a bare ':' and no speaker label;\n",
 "    # confirm this matches the turn format the adapter was fine-tuned on.\n",
 "    template = \"\"\"Answer the following question truthfully.\n",
 "    If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.\n",
 "    If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'.\n",
 "\n",
 "    Example Format:\n",
 "    : question here\n",
 "    : answer here\n",
 "\n",
 "    Begin!\n",
 "\n",
 "    : {query}\n",
 "    :\"\"\"\n",
 "\n",
 "    prompt = PromptTemplate(template=template, input_variables=[\"query\"])\n",
 "    llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
 "\n",
 "    return llm_chain"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "10h_KVGilk2J" }, "outputs": [], "source": [
 "def user(user_message, history):\n",
 "    \"\"\"Gradio submit handler: clear the textbox and append a new [user_message, None] turn.\n",
 "\n",
 "    Returns a new list (history itself is not mutated); a None history is treated as empty.\n",
 "    \"\"\"\n",
 "    return \"\", (history or []) + [[user_message, None]]"
] },
{ "cell_type": "code",
"execution_count": null, "metadata": { "id": "QeFE1qZnluMm" }, "outputs": [], "source": [
 "def bot(history):\n",
 "    \"\"\"Gradio handler: answer the latest user turn and store the reply in history.\n",
 "\n",
 "    If an earlier exchange exists, it is prepended to the new question so the model\n",
 "    sees one turn of context. Sets history[-1][1] in place and returns history.\n",
 "    \"\"\"\n",
 "    latest_question = history[-1][0]\n",
 "    if len(history) >= 2:\n",
 "        # user() appends turns with a None reply that is only filled once bot() succeeds,\n",
 "        # so an earlier reply may still be None; treat it as empty instead of raising.\n",
 "        previous_answer = history[-2][1] or \"\"\n",
 "        query = history[-2][0] + \"\\n\" + previous_answer + \"\\nHere, is the next QUESTION: \" + latest_question\n",
 "    else:\n",
 "        query = latest_question\n",
 "\n",
 "    # llm_chain is the module-level chain created in the Gradio cell.\n",
 "    bot_message = post_process_chat(llm_chain.run(query))\n",
 "\n",
 "    history[-1][1] = bot_message\n",
 "    return history"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "aae3uAD5lyXN" }, "outputs": [], "source": [
 "def post_process_chat(bot_message):\n",
 "    \"\"\"Extract the answer from the decoded model output and tidy it for display.\n",
 "\n",
 "    Args:\n",
 "        bot_message: decoded text returned by llm_chain.run (prompt plus continuation).\n",
 "\n",
 "    Returns:\n",
 "        The text after the last turn marker, cut at \"Begin!\", trimmed to its last full\n",
 "        stop, without a trailing list number or sign-off, and with blank lines collapsed.\n",
 "    \"\"\"\n",
 "    # If the model repeated the prompt template, a second ':...Begin!' span exists; keep it.\n",
 "    try:\n",
 "        bot_message = re.findall(r\":.*?Begin!\", bot_message, re.DOTALL)[1]\n",
 "    except IndexError:\n",
 "        pass\n",
 "\n",
 "    # Keep the text after the last turn marker: a ':' that starts a line (after optional\n",
 "    # indentation and an optional <LABEL> speaker tag), as in the prompt template.\n",
 "    # The colon must be mandatory: a pattern that can match the empty string\n",
 "    # (such as r'\\:?\\s?') makes re.split cut between every character, leaving [-1] == ''.\n",
 "    bot_message = re.split(r\"(?m)^[ \\t]*(?:<\\w+>)?:[ \\t]?\", bot_message)[-1].split(\"Begin!\")[0]\n",
 "\n",
 "    # Truncate after the last full stop, dropping a trailing unfinished sentence.\n",
 "    try:\n",
 "        bot_message = re.search(r\"(.*\\.)\", bot_message, re.DOTALL).group(1)\n",
 "    except AttributeError:\n",
 "        # No full stop at all: keep the text as it is.\n",
 "        pass\n",
 "\n",
 "    # Drop a dangling list number such as \"\\n3.\" at the very end.\n",
 "    bot_message = re.sub(r\"\\n\\d.$\", \"\", bot_message)\n",
 "    # Cut at the first sign-off phrase (case-insensitive).\n",
 "    bot_message = re.split(r\"(Goodbye|Take care|Best Wishes)\", bot_message, flags=re.IGNORECASE)[0].strip()\n",
 "    bot_message = bot_message.replace(\"\\n\\n\", \"\\n\")\n",
 "\n",
 "    return bot_message"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 49, "referenced_widgets": [ "34b559e7dcc245d59cf63059f89854e1", "c6e4d63554564591ab38a4e633f60ba4", "822036e3276b40a98c989dd7af3b690d", "e9c50bc4ec04407386d5a4dea42a18bd", "350fa3821e104536b42c1f70985f1ee4", "075c92e6397242e495834f5dbdd074fe", "06c033e936fb4b9faaa2a13150855a04", "3306aab2825e424db73536be38fe774e", "9c6feacfae344692b5c71d16d687c74e", "20c8b95e503a44de81905b652b9291c7", "b05c18c3366a48bf81c1260e553995ff" ] }, "id": "cjZ9ENNnSpeY", "outputId": "671b81ea-4789-4555-e96a-c69797cb6a13" }, "outputs": [ { "data": {
"application/vnd.jupyter.widget-view+json": { "model_id": "34b559e7dcc245d59cf63059f89854e1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/8 [00:00Welcome to Mental Health Conversational AI\"\"\")\n", " gr.Markdown(\n", " \"\"\"Chatbot specifically designed to provide psychoeducation, offer non-judgemental and empathetic support, self-assessment and monitoring.
\n", " Get instant responses to any mental-health-related query. If the chatbot senses that you need external support, it will respond appropriately.
\"\"\"\n", " )\n", "\n", " chatbot = gr.Chatbot()\n", " query = gr.Textbox(label=\"Type your query here, then press 'enter' and scroll up for response\")\n", " clear = gr.Button(value=\"Clear Chat History!\")\n", " clear.style(size=\"sm\")\n", "\n", " llm_chain = init_llm_chain(peft_model, peft_tokenizer)\n", "\n", " query.submit(user, [query, chatbot], [query, chatbot], queue=False).then(bot, chatbot, chatbot)\n", " clear.click(lambda: None, None, chatbot, queue=False)\n", "\n", "demo.queue().launch()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bl3UYt3dUF6H" }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "machine_shape": "hm", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "06c033e936fb4b9faaa2a13150855a04": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "075c92e6397242e495834f5dbdd074fe": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, 
"grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "20c8b95e503a44de81905b652b9291c7": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "3306aab2825e424db73536be38fe774e": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": 
null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "34b559e7dcc245d59cf63059f89854e1": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_c6e4d63554564591ab38a4e633f60ba4", "IPY_MODEL_822036e3276b40a98c989dd7af3b690d", "IPY_MODEL_e9c50bc4ec04407386d5a4dea42a18bd" ], "layout": "IPY_MODEL_350fa3821e104536b42c1f70985f1ee4" } }, "350fa3821e104536b42c1f70985f1ee4": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": 
null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "822036e3276b40a98c989dd7af3b690d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_3306aab2825e424db73536be38fe774e", "max": 8, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_9c6feacfae344692b5c71d16d687c74e", "value": 8 } }, "9c6feacfae344692b5c71d16d687c74e": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "b05c18c3366a48bf81c1260e553995ff": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", 
"_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "c6e4d63554564591ab38a4e633f60ba4": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_075c92e6397242e495834f5dbdd074fe", "placeholder": "​", "style": "IPY_MODEL_06c033e936fb4b9faaa2a13150855a04", "value": "Loading checkpoint shards: 100%" } }, "e9c50bc4ec04407386d5a4dea42a18bd": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_20c8b95e503a44de81905b652b9291c7", "placeholder": "​", "style": "IPY_MODEL_b05c18c3366a48bf81c1260e553995ff", "value": " 8/8 [01:32<00:00, 9.87s/it]" } } } } }, "nbformat": 4, "nbformat_minor": 0 }