{ "cells": [ { "cell_type": "markdown", "id": "624e22d0-f2b3-4708-bc4d-d96bce26cdde", "metadata": {}, "source": [ "# Mental Health Nudging with Generative AI Demo\n", "\n", "The code below is duplicated in the `app.py` file.\n", "This notebook is provided for ease of development and debugging.\n", "\n", "Running it requires an OpenAI API key, which `OpenAI()` reads from the `OPENAI_API_KEY` environment variable by default." ] }, { "cell_type": "code", "execution_count": 1, "id": "6b57ced9-62ee-44a0-a895-6ed288f970ff", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7860\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import io\n", "import base64\n", "import gradio as gr\n", "from PIL import Image\n", "from openai import OpenAI\n", "\n", "\n", "def run_demo():\n", " \"\"\"Setup the app interface and launch it.\"\"\"\n", " with gr.Blocks() as app:\n", "\n", " gr.Markdown('# Mental Health Nudging with Generative AI Demo')\n", " with gr.Row():\n", "\n", " # input features\n", " with gr.Column(scale=2):\n", "\n", " # demographics\n", " gender = gr.Radio(label='Gender', value='N/A',\n", " choices=['Male', 'Female', 'Non-Binary', 'N/A'])\n", " age = gr.Slider(label='Age', minimum=18, maximum=80, step=1)\n", " race = gr.Radio(label='Race', value='N/A',\n", " choices=['White', 'Hispanic', 'Black', 'Asian', 'N/A'])\n", "\n", " # symptoms\n", " disorders = ['Sadness', 'Inability to concentrate', 'Anxiety', 'Extreme mood changes',\n", " 'Social withdrawal', 'Tiredness', 'Lack of appetite', 'Increased appetite']\n", " symptoms = gr.CheckboxGroup(label='Symptoms', choices=disorders)\n", "\n", " # interests\n", " interests = gr.Textbox(label='Interests', placeholder='Comma-separated list of interests...')\n", "\n", " # submit button\n", " submit_button = gr.Button('Generate Nudge')\n", "\n", " # resulting nudge\n", " with gr.Column(scale=1):\n", " nudge_image = gr.Image(label='Nudge Image')\n", " nudge_message = gr.Textbox(label='Nudge Message')\n", "\n", " # submit parameters for nudge generation\n", " inputs = [gender, age, race, interests, symptoms]\n", " outputs = [nudge_image, nudge_message]\n", " submit_button.click(fn=generate, inputs=inputs, outputs=outputs)\n", "\n", " # launch the app\n", " gr.close_all()\n", " app.queue(default_concurrency_limit=None)\n", " app.launch()\n", "\n", "\n", "def generate(gender, age, race, interests, symptoms):\n", " \"\"\"Generate nudging image and message for the given person.\"\"\"\n", " nudge_message = generate_nudge_message(gender, age, interests, 
symptoms)\n", " nudge_image = generate_nudge_image(gender, age, race, nudge_message)\n", " return nudge_image, nudge_message\n", "\n", "\n", "def generate_nudge_message(gender, age, interests, symptoms):\n", " \"\"\"Generate a message for a given person.\"\"\"\n", " # construct description of the person\n", " desc = f'A {age} year old '\n", " if gender == 'Male':\n", " desc += 'man.'\n", " elif gender == 'Female':\n", " desc += 'woman.'\n", " elif gender == 'Non-Binary':\n", " desc += 'non-binary person.'\n", " else:\n", " desc += 'person.'\n", " if interests:\n", " desc += f' They like {interests}.'\n", " if symptoms:\n", " desc += f' They have the following mental health symptoms: {\", \".join(map(str.lower, symptoms))}.'\n", " else:\n", " desc += f' They do not have any mental health symptoms.'\n", "\n", " # generate nudge message\n", " system_prompt = 'You are writing motivational text messages to help people with their mental health. '\\\n", " + 'Messages should be friendly and positive, but also professional and super short. '\\\n", " + 'You are limited on space. Messages should be written at the reading level of an eighth grader. '\\\n", " + 'Word choice should be short and simple so everyone can understand. \\n\\n'\\\n", " + 'You will be given some basic information about the person you are addressing. '\\\n", " + 'Messages should be short, so be discerning. 
You should try to use the person\\'s '\\\n", " + 'information to give them relevant and actionable tips for improving their mental health symptoms.'\n", " user_prompt = f'Write a short inspirational message for the person with the following description:\\n\\n{desc}'\n", " messages = [{'role': 'system', 'content': f'{system_prompt}'},\n", " {'role': 'user', 'content': f'{user_prompt}'}]\n", " completion = client.chat.completions.create(messages=messages, model='gpt-3.5-turbo', temperature=.5)\n", " nudge_message = completion.choices[0].message.content\n", "\n", " return nudge_message\n", "\n", "\n", "def generate_nudge_image(gender, age, race, nudge_message):\n", " \"\"\"Generate an image for a given person and message.\"\"\"\n", " # construct description of the person\n", " desc = f'a {age} year old '\n", " if race != 'N/A':\n", " desc += f'{race.lower()} '\n", " if gender == 'Male':\n", " desc += 'man.'\n", " elif gender == 'Female':\n", " desc += 'woman.'\n", " elif gender == 'Non-Binary':\n", " desc += 'non-binary person.'\n", " else:\n", " desc += 'person.'\n", "\n", " # generate nudge image\n", " prompt = 'Illustrate one simple, inspirational, fun image to help a person with their mental health. NO TEXT. '\\\n", " + f'The style is cute and illustrative. 
It is focused on {desc} '\\\n", " + f'The image should suit the following message:\\n\\n{nudge_message}'\n", " response = client.images.generate(prompt=prompt, model='dall-e-3', response_format='b64_json')\n", " nudge_image = Image.open(io.BytesIO(base64.b64decode(response.data[0].b64_json)))\n", "\n", " return nudge_image\n", "\n", "\n", "if __name__ == '__main__':\n", " client = OpenAI()\n", " run_demo()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "6bf1dd4a-b7c5-496d-9a7a-b4a392e185e2", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }