File size: 7,919 Bytes
05094a9
 
1628190
 
 
 
 
 
 
 
 
 
 
05094a9
 
1628190
05094a9
 
 
 
 
 
 
1628190
05094a9
 
 
 
 
 
 
1628190
05094a9
 
 
 
 
 
 
 
 
 
d070bf6
 
05094a9
 
 
 
 
 
 
 
 
 
 
 
 
f5a32f3
05094a9
 
1628190
 
05094a9
1628190
 
28e0d23
 
a306a6c
 
6ae4710
05094a9
 
6ae4710
05094a9
 
 
 
 
 
 
 
 
 
6ae4710
05094a9
 
 
 
 
 
 
 
 
6ae4710
05094a9
6ae4710
1628190
6ae4710
 
 
 
 
05094a9
6ae4710
 
 
 
 
 
 
28e0d23
6ae4710
 
 
05094a9
6ae4710
05094a9
28e0d23
6ae4710
 
1628190
 
 
 
 
 
 
05094a9
 
 
28e0d23
05094a9
1628190
05094a9
 
 
1628190
05094a9
6ae4710
a306a6c
1628190
6ae4710
 
 
 
 
 
 
 
 
1628190
 
a306a6c
 
 
d070bf6
 
05094a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1628190
05094a9
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "624e22d0-f2b3-4708-bc4d-d96bce26cdde",
   "metadata": {},
   "source": [
    "# Mental Health Nudging with Generative AI Demo\n",
    "\n",
    "The code below is a copy of the code in `app.py`; keep the two in sync when making changes.\n",
    "This notebook is provided for ease of development and debugging."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6b57ced9-62ee-44a0-a895-6ed288f970ff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL:  http://127.0.0.1:7860\n",
      "\n",
      "To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import io\n",
    "import base64\n",
    "import gradio as gr\n",
    "from PIL import Image\n",
    "from openai import OpenAI\n",
    "\n",
    "\n",
    "def run_demo():\n",
    "    \"\"\"Setup the app interface and launch it.\"\"\"\n",
    "    with gr.Blocks() as app:\n",
    "\n",
    "        gr.Markdown('# Mental Health Nudging with Generative AI Demo')\n",
    "        with gr.Row():\n",
    "\n",
    "            # input features\n",
    "            with gr.Column(scale=2):\n",
    "\n",
    "                # demographics\n",
    "                gender = gr.Radio(label='Gender', value='N/A',\n",
    "                                  choices=['Male', 'Female', 'Non-Binary', 'N/A'])\n",
    "                age = gr.Slider(label='Age', minimum=18, maximum=80, step=1)\n",
    "                race = gr.Radio(label='Race', value='N/A',\n",
    "                                choices=['White', 'Hispanic', 'Black', 'Asian', 'N/A'])\n",
    "\n",
    "                # symptoms\n",
    "                disorders = ['Sadness', 'Inability to concentrate', 'Anxiety', 'Extreme mood changes',\n",
    "                             'Social withdrawal', 'Tiredness', 'Lack of appetite', 'Increased appetite']\n",
    "                symptoms = gr.CheckboxGroup(label='Symptoms', choices=disorders)\n",
    "\n",
    "                # interests\n",
    "                interests = gr.Textbox(label='Interests', placeholder='Comma-separated list of interests...')\n",
    "\n",
    "                # submit button\n",
    "                submit_button = gr.Button('Generate Nudge')\n",
    "\n",
    "            # resulting nudge\n",
    "            with gr.Column(scale=1):\n",
    "                nudge_image = gr.Image(label='Nudge Image')\n",
    "                nudge_message = gr.Textbox(label='Nudge Message')\n",
    "\n",
    "        # submit parameters for nudge generation\n",
    "        inputs = [gender, age, race, interests, symptoms]\n",
    "        outputs = [nudge_image, nudge_message]\n",
    "        submit_button.click(fn=generate, inputs=inputs, outputs=outputs)\n",
    "\n",
    "    # launch the app\n",
    "    gr.close_all()\n",
    "    app.queue(default_concurrency_limit=None)\n",
    "    app.launch()\n",
    "\n",
    "\n",
    "def generate(gender, age, race, interests, symptoms):\n",
    "    \"\"\"Generate nudging image and message for the given person.\"\"\"\n",
    "    nudge_message = generate_nudge_message(gender, age, interests, symptoms)\n",
    "    nudge_image = generate_nudge_image(gender, age, race, nudge_message)\n",
    "    return nudge_image, nudge_message\n",
    "\n",
    "\n",
    "def generate_nudge_message(gender, age, interests, symptoms):\n",
    "    \"\"\"Generate a message for a given person.\"\"\"\n",
    "    # construct description of the person\n",
    "    desc = f'A {age} year old '\n",
    "    if gender == 'Male':\n",
    "        desc += 'man.'\n",
    "    elif gender == 'Female':\n",
    "        desc += 'woman.'\n",
    "    elif gender == 'Non-Binary':\n",
    "        desc += 'non-binary person.'\n",
    "    else:\n",
    "        desc += 'person.'\n",
    "    if interests:\n",
    "        desc += f' They like {interests}.'\n",
    "    if symptoms:\n",
    "        desc += f' They have the following mental health symptoms: {\", \".join(map(str.lower, symptoms))}.'\n",
    "    else:\n",
    "        desc += f' They do not have any mental health symptoms.'\n",
    "\n",
    "    # generate nudge message\n",
    "    system_prompt = 'You are writing motivational text messages to help people with their mental health. '\\\n",
    "                  + 'Messages should be friendly and positive, but also professional and super short. '\\\n",
    "                  + 'You are limited on space. Messages should be written at the reading level of an eighth grader. '\\\n",
    "                  + 'Word choice should be short and simple so everyone can understand. \\n\\n'\\\n",
    "                  + 'You will be given some basic information about the person you are addressing. '\\\n",
    "                  + 'Messages should be short, so be discerning. You should try to use the person\\'s '\\\n",
    "                  + 'information to give them relevant and actionable tips for improving their mental health symptoms.'\n",
    "    user_prompt = f'Write a short inspirational message for the person with the following description:\\n\\n{desc}'\n",
    "    messages = [{'role': 'system', 'content': f'{system_prompt}'},\n",
    "                {'role': 'user', 'content': f'{user_prompt}'}]\n",
    "    completion = client.chat.completions.create(messages=messages, model='gpt-3.5-turbo', temperature=.5)\n",
    "    nudge_message = completion.choices[0].message.content\n",
    "\n",
    "    return nudge_message\n",
    "\n",
    "\n",
    "def generate_nudge_image(gender, age, race, nudge_message):\n",
    "    \"\"\"Generate an image for a given person and message.\"\"\"\n",
    "    # construct description of the person\n",
    "    desc = f'a {age} year old '\n",
    "    if race != 'N/A':\n",
    "        desc += f'{race.lower()} '\n",
    "    if gender == 'Male':\n",
    "        desc += 'man.'\n",
    "    elif gender == 'Female':\n",
    "        desc += 'woman.'\n",
    "    elif gender == 'Non-Binary':\n",
    "        desc += 'non-binary person.'\n",
    "    else:\n",
    "        desc += 'person.'\n",
    "\n",
    "    # generate nudge image\n",
    "    prompt = 'Illustrate one simple, inspirational, fun image to help a person with their mental health. NO TEXT. '\\\n",
    "           + f'The style is cute and illustrative. It is focused on {desc} '\\\n",
    "           + f'The image should suit the following message:\\n\\n{nudge_message}'\n",
    "    response = client.images.generate(prompt=prompt, model='dall-e-3', response_format='b64_json')\n",
    "    nudge_image = Image.open(io.BytesIO(base64.b64decode(response.data[0].b64_json)))\n",
    "\n",
    "    return nudge_image\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    client = OpenAI()\n",
    "    run_demo()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bf1dd4a-b7c5-496d-9a7a-b4a392e185e2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}