sukiboo commited on
Commit
05094a9
·
1 Parent(s): a06ff9d

initial code upload

Browse files
Files changed (1) hide show
  1. app.ipynb +271 -0
app.ipynb ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 71,
6
+ "id": "6b57ced9-62ee-44a0-a895-6ed288f970ff",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "Running on local URL: http://127.0.0.1:7899\n",
14
+ "\n",
15
+ "To create a public link, set `share=True` in `launch()`.\n"
16
+ ]
17
+ },
18
+ {
19
+ "data": {
20
+ "text/html": [
21
+ "<div><iframe src=\"http://127.0.0.1:7899/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
22
+ ],
23
+ "text/plain": [
24
+ "<IPython.core.display.HTML object>"
25
+ ]
26
+ },
27
+ "metadata": {},
28
+ "output_type": "display_data"
29
+ },
30
+ {
31
+ "name": "stdout",
32
+ "output_type": "stream",
33
+ "text": [
34
+ "desc='A 42 year old woman. They like cats, cooking. They do not like outside. They have the following symptoms: Anxiety.'\n",
35
+ "prompt='Illustrate an inspirational message to help a person with their mental health. The person is described as follows:\\n\\nA 42 year old woman. They like cats, cooking. They do not like outside. They have the following symptoms: Anxiety.\\n\\nThe image should align with the following message:\\n\\nDear friend, I understand that anxiety can sometimes feel overwhelming, but remember, you are stronger than you think. Just like cats have their own safe spaces, create a sanctuary within yourself where you can find peace. Embrace the comforting art of cooking, allowing it to ignite your passion and bring you joy. Though the outside may seem daunting, know that within your soul lies the power to conquer any fear. Be kind to yourself, take small steps, and remember that you are capable of remarkable things.'\n"
36
+ ]
37
+ }
38
+ ],
39
+ "source": [
40
+ "import numpy as np\n",
41
+ "import gradio as gr\n",
42
+ "import requests\n",
43
+ "from PIL import Image\n",
44
+ "from openai import OpenAI\n",
45
+ "\n",
46
+ "\n",
47
+ "def run_demo():\n",
48
+ " \"\"\"Setup the app interface and launch it.\"\"\"\n",
49
+ " with gr.Blocks() as app:\n",
50
+ "\n",
51
+ " gr.Markdown('# Mental Health Nudging with Generative AI Demo')\n",
52
+ " with gr.Row():\n",
53
+ "\n",
54
+ " # input features\n",
55
+ " with gr.Column(scale=1):\n",
56
+ "\n",
57
+ " # demographics\n",
58
+ " gender = gr.Radio(label='Gender', choices=['Male', 'Female', 'Non-Binary'])\n",
59
+ " age = gr.Slider(label='Age', minimum=18, maximum=80, step=1)\n",
60
+ "\n",
61
+ " # symptoms\n",
62
+ " symptoms = gr.CheckboxGroup(label='Symptoms',\n",
63
+ " choices=['Anxiety', 'Depression', 'Tiredness', 'Apathy'])\n",
64
+ "\n",
65
+ " # interests\n",
66
+ " likes = gr.Textbox(label='Liked Activities',\n",
67
+ " placeholder='Comma-separated list of likes...')\n",
68
+ " dislikes = gr.Textbox(label='Disliked Activities',\n",
69
+ " placeholder='Comma-separated list of dislikes...')\n",
70
+ "\n",
71
+ " # submit button\n",
72
+ " submit_button = gr.Button('Generate Nudge')\n",
73
+ "\n",
74
+ " # resulting nudge\n",
75
+ " with gr.Column(scale=1):\n",
76
+ " nudge_image = gr.Image(label='Nudge Image')\n",
77
+ " nudge_message = gr.Textbox(label='Nudge Message')\n",
78
+ "\n",
79
+ " # submit parameters for nudge generation\n",
80
+ " inputs = [gender, age, symptoms, likes, dislikes]\n",
81
+ " outputs = [nudge_image, nudge_message]\n",
82
+ " submit_button.click(fn=generate, inputs=inputs, outputs=outputs)\n",
83
+ "\n",
84
+ " # launch the app\n",
85
+ " gr.close_all()\n",
86
+ " app.queue(default_concurrency_limit=None)\n",
87
+ " app.launch()\n",
88
+ "\n",
89
+ "\n",
90
+ "def generate(gender, age, symptoms, likes, dislikes):\n",
91
+ " \"\"\"Generate nudging image and message for the given person.\"\"\"\n",
92
+ " # construct description of the person\n",
93
+ " desc = f'A {age} year old {\"man\" if gender==\"Male\" else \"woman\" if gender==\"Female\" else \"person\"}.'\n",
94
+ " if likes:\n",
95
+ " desc += f' They like {likes}.'\n",
96
+ " if dislikes:\n",
97
+ " desc += f' They do not like {dislikes}.'\n",
98
+ " if symptoms:\n",
99
+ " desc += f' They have the following symptoms: {\", \".join(symptoms)}.'\n",
100
+ " else:\n",
101
+ " desc += f' They do not have any symptoms.'\n",
102
+ "\n",
103
+ " # generate nudge for the person\n",
104
+ " nudge_message = generate_nudge_message(desc)\n",
105
+ " nudge_image = generate_nudge_image(desc, nudge_message, random=False)\n",
106
+ "\n",
107
+ " return nudge_image, nudge_message\n",
108
+ "\n",
109
+ "\n",
110
+ "def generate_nudge_message(desc):\n",
111
+ " \"\"\"Generate a message for a given person.\"\"\"\n",
112
+ " print(f'{desc=}')\n",
113
+ " system_prompt = 'You are a behavioral content writer. You job is to write short '\\\n",
114
+ " + 'motivational messages to help people with their mental health. '\\\n",
115
+ " + 'Messages should be at most a paprgraph long.'\n",
116
+ " user_prompt = f'Write a short inspirational message for the person with the following description:\\n\\n{desc}'\n",
117
+ " messages = [{'role': 'system', 'content': f'{system_prompt}'},\n",
118
+ " {'role': 'user', 'content': f'{user_prompt}'}]\n",
119
+ " completion = client.chat.completions.create(messages=messages, model='gpt-3.5-turbo')\n",
120
+ " nudge_message = completion.choices[0].message.content\n",
121
+ " return nudge_message\n",
122
+ "\n",
123
+ "\n",
124
+ "def generate_nudge_image(desc, nudge_message, random=True):\n",
125
+ " \"\"\"Generate an image for a given person and message.\"\"\"\n",
126
+ " if random:\n",
127
+ " return Image.fromarray(np.random.randint(0, 255, (100, 100, 3), dtype='uint8'), 'RGB')\n",
128
+ "\n",
129
+ " prompt = 'Illustrate an inspirational message to help a person with their mental health. '\\\n",
130
+ " + f'The person is described as follows:\\n\\n{desc}\\n\\n'\\\n",
131
+ " + f'The image should align with the following message:\\n\\n{nudge_message}'\n",
132
+ "\n",
133
+ " print(f'{prompt=}')\n",
134
+ " response = client.images.generate(prompt=prompt, model='dall-e-2', style='vivid', quality='standard')\n",
135
+ " nudge_image = Image.open(requests.get(response.data[0].url, stream=True).raw)\n",
136
+ "\n",
137
+ " return nudge_image\n",
138
+ "\n",
139
+ "\n",
140
+ "if __name__ == '__main__':\n",
141
+ " client = OpenAI()\n",
142
+ " run_demo()\n"
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "code",
147
+ "execution_count": null,
148
+ "id": "f5dbbb9e-e706-4789-82d2-0f83b80c65da",
149
+ "metadata": {},
150
+ "outputs": [],
151
+ "source": []
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": null,
156
+ "id": "6bf1dd4a-b7c5-496d-9a7a-b4a392e185e2",
157
+ "metadata": {},
158
+ "outputs": [],
159
+ "source": []
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "id": "83da6055-6ca8-44ee-a64f-be3b29fc400a",
165
+ "metadata": {},
166
+ "outputs": [],
167
+ "source": []
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "execution_count": null,
172
+ "id": "24bbd8a2-600e-4b46-aa58-95caa76c2c88",
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": []
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": null,
180
+ "id": "7205584e-9526-42b0-8f8f-a7c3901831d1",
181
+ "metadata": {},
182
+ "outputs": [],
183
+ "source": []
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": null,
188
+ "id": "c7cfcc42-b369-465c-902f-e01b4e017fb8",
189
+ "metadata": {},
190
+ "outputs": [],
191
+ "source": []
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": null,
196
+ "id": "392dbdda-2a94-4671-b935-4a4706d5b5f6",
197
+ "metadata": {},
198
+ "outputs": [],
199
+ "source": []
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": null,
204
+ "id": "ec9c7160-a50c-4f2d-9450-70586157ca05",
205
+ "metadata": {},
206
+ "outputs": [],
207
+ "source": []
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "execution_count": null,
212
+ "id": "08545064-3b81-4216-bce7-df58bdef7aeb",
213
+ "metadata": {},
214
+ "outputs": [],
215
+ "source": []
216
+ },
217
+ {
218
+ "cell_type": "code",
219
+ "execution_count": null,
220
+ "id": "dd6261bf-f3bd-45cb-bcc7-d517d44f7282",
221
+ "metadata": {},
222
+ "outputs": [],
223
+ "source": []
224
+ },
225
+ {
226
+ "cell_type": "code",
227
+ "execution_count": null,
228
+ "id": "1e14977a-be46-4d92-a512-cf19bb39d4af",
229
+ "metadata": {},
230
+ "outputs": [],
231
+ "source": []
232
+ },
233
+ {
234
+ "cell_type": "code",
235
+ "execution_count": null,
236
+ "id": "86b9de25-7c42-4e6f-bbf5-b0cb78fd0331",
237
+ "metadata": {},
238
+ "outputs": [],
239
+ "source": []
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": null,
244
+ "id": "ec4e3629-07a5-4121-8752-8f6cfe20b561",
245
+ "metadata": {},
246
+ "outputs": [],
247
+ "source": []
248
+ }
249
+ ],
250
+ "metadata": {
251
+ "kernelspec": {
252
+ "display_name": "Python 3 (ipykernel)",
253
+ "language": "python",
254
+ "name": "python3"
255
+ },
256
+ "language_info": {
257
+ "codemirror_mode": {
258
+ "name": "ipython",
259
+ "version": 3
260
+ },
261
+ "file_extension": ".py",
262
+ "mimetype": "text/x-python",
263
+ "name": "python",
264
+ "nbconvert_exporter": "python",
265
+ "pygments_lexer": "ipython3",
266
+ "version": "3.10.12"
267
+ }
268
+ },
269
+ "nbformat": 4,
270
+ "nbformat_minor": 5
271
+ }