Rohan Kumar Singh commited on
Commit
e2d2960
1 Parent(s): 0fcc9d3

initial commit

Browse files
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ best-model.ckpt filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "e0102cb4",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "Global seed set to 100\n"
14
+ ]
15
+ },
16
+ {
17
+ "data": {
18
+ "text/plain": [
19
+ "100"
20
+ ]
21
+ },
22
+ "execution_count": 1,
23
+ "metadata": {},
24
+ "output_type": "execute_result"
25
+ }
26
+ ],
27
+ "source": [
28
+ "from transformers import T5Tokenizer, T5ForConditionalGeneration \n",
29
+ "\n",
30
+ "from transformers import AdamW\n",
31
+ "import pandas as pd\n",
32
+ "import torch\n",
33
+ "import pytorch_lightning as pl\n",
34
+ "from pytorch_lightning.callbacks import ModelCheckpoint\n",
35
+ "from torch.nn.utils.rnn import pad_sequence\n",
36
+ "# from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler\n",
37
+ "\n",
38
+ "pl.seed_everything(100)"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": 2,
44
+ "id": "1ec5ec2a",
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": [
48
+ "MODEL_NAME='t5-base'"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 3,
54
+ "id": "8044c622",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
59
+ "INPUT_MAX_LEN = 128 \n",
60
+ "OUTPUT_MAX_LEN = 128"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": 4,
66
+ "id": "6390f2de",
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=512)"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": 5,
76
+ "id": "8eec35d1",
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "class T5Model(pl.LightningModule):\n",
81
+ " \n",
82
+ " def __init__(self):\n",
83
+ " super().__init__()\n",
84
+ " self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict = True)\n",
85
+ "\n",
86
+ " \n",
87
+ " def forward(self, input_ids, attention_mask, labels=None):\n",
88
+ " \n",
89
+ " output = self.model(\n",
90
+ " input_ids=input_ids, \n",
91
+ " attention_mask=attention_mask, \n",
92
+ " labels=labels\n",
93
+ " )\n",
94
+ " return output.loss, output.logits\n",
95
+ " \n",
96
+ " def training_step(self, batch, batch_idx):\n",
97
+ "\n",
98
+ " input_ids = batch[\"input_ids\"]\n",
99
+ " attention_mask = batch[\"attention_mask\"]\n",
100
+ " labels= batch[\"target\"]\n",
101
+ " loss, logits = self(input_ids , attention_mask, labels)\n",
102
+ "\n",
103
+ " \n",
104
+ " self.log(\"train_loss\", loss, prog_bar=True, logger=True)\n",
105
+ "\n",
106
+ " return {'loss': loss}\n",
107
+ " \n",
108
+ " def validation_step(self, batch, batch_idx):\n",
109
+ " input_ids = batch[\"input_ids\"]\n",
110
+ " attention_mask = batch[\"attention_mask\"]\n",
111
+ " labels= batch[\"target\"]\n",
112
+ " loss, logits = self(input_ids, attention_mask, labels)\n",
113
+ "\n",
114
+ " self.log(\"val_loss\", loss, prog_bar=True, logger=True)\n",
115
+ " \n",
116
+ " return {'val_loss': loss}\n",
117
+ "\n",
118
+ " def configure_optimizers(self):\n",
119
+ " return AdamW(self.parameters(), lr=0.0001)"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": 6,
125
+ "id": "e9d96844",
126
+ "metadata": {},
127
+ "outputs": [
128
+ {
129
+ "name": "stderr",
130
+ "output_type": "stream",
131
+ "text": [
132
+ "Lightning automatically upgraded your loaded checkpoint from v1.9.3 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file F:\\Projects & Open_source\\Chatbot_T5_kaggle\\best-model.ckpt`\n"
133
+ ]
134
+ }
135
+ ],
136
+ "source": [
137
+ "train_model = T5Model.load_from_checkpoint('best-model.ckpt',map_location=DEVICE)"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": 7,
143
+ "id": "3449943f",
144
+ "metadata": {},
145
+ "outputs": [],
146
+ "source": [
147
+ "train_model.freeze()"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": 8,
153
+ "id": "0e9f1058",
154
+ "metadata": {},
155
+ "outputs": [],
156
+ "source": [
157
+ "def generate_question(question):\n",
158
+ "\n",
159
+ " inputs_encoding = tokenizer(\n",
160
+ " question,\n",
161
+ " add_special_tokens=True,\n",
162
+ " max_length= INPUT_MAX_LEN,\n",
163
+ " padding = 'max_length',\n",
164
+ " truncation='only_first',\n",
165
+ " return_attention_mask=True,\n",
166
+ " return_tensors=\"pt\"\n",
167
+ " )\n",
168
+ "\n",
169
+ " \n",
170
+ " generate_ids = train_model.model.generate(\n",
171
+ " input_ids = inputs_encoding[\"input_ids\"],\n",
172
+ " attention_mask = inputs_encoding[\"attention_mask\"],\n",
173
+ " max_length = INPUT_MAX_LEN,\n",
174
+ " num_beams = 4,\n",
175
+ " num_return_sequences = 1,\n",
176
+ " no_repeat_ngram_size=2,\n",
177
+ " early_stopping=True,\n",
178
+ " )\n",
179
+ "\n",
180
+ " preds = [\n",
181
+ " tokenizer.decode(gen_id,\n",
182
+ " skip_special_tokens=True, \n",
183
+ " clean_up_tokenization_spaces=True)\n",
184
+ " for gen_id in generate_ids\n",
185
+ " ]\n",
186
+ "\n",
187
+ " return \"\".join(preds)\n"
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "execution_count": 9,
193
+ "id": "ee38a88c",
194
+ "metadata": {},
195
+ "outputs": [
196
+ {
197
+ "name": "stdout",
198
+ "output_type": "stream",
199
+ "text": [
200
+ "Ques: hi, how are you doing?\n",
201
+ "BOT: i'm so glad you're doing well.\n"
202
+ ]
203
+ }
204
+ ],
205
+ "source": [
206
+ "ques = \"hi, how are you doing?\"\n",
207
+ "print(\"Ques: \",ques)\n",
208
+ "print(\"BOT: \",generate_question(ques))"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": 11,
214
+ "id": "22aa4414",
215
+ "metadata": {},
216
+ "outputs": [
217
+ {
218
+ "name": "stdout",
219
+ "output_type": "stream",
220
+ "text": [
221
+ "Running on local URL: http://127.0.0.1:7861\n",
222
+ "\n",
223
+ "To create a public link, set `share=True` in `launch()`.\n"
224
+ ]
225
+ },
226
+ {
227
+ "data": {
228
+ "text/html": [
229
+ "<div><iframe src=\"http://127.0.0.1:7861/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
230
+ ],
231
+ "text/plain": [
232
+ "<IPython.core.display.HTML object>"
233
+ ]
234
+ },
235
+ "metadata": {},
236
+ "output_type": "display_data"
237
+ },
238
+ {
239
+ "data": {
240
+ "text/plain": []
241
+ },
242
+ "execution_count": 11,
243
+ "metadata": {},
244
+ "output_type": "execute_result"
245
+ }
246
+ ],
247
+ "source": [
248
+ "import gradio as gr\n",
249
+ "import random\n",
250
+ "import time\n",
251
+ "\n",
252
+ "with gr.Blocks() as demo:\n",
253
+ " chatbot = gr.Chatbot()\n",
254
+ " gr.Chatbot.style(chatbot,height=400)\n",
255
+ " msg = gr.Textbox(info=\"Press \\'Enter\\' to send\")\n",
256
+ " clear = gr.Button(\"Clear\")\n",
257
+ "\n",
258
+ " def user(user_message, history):\n",
259
+ " return \"\", history + [[user_message, None]]\n",
260
+ "\n",
261
+ " def bot(history):\n",
262
+ " bot_message = generate_question(history[-1][0])\n",
263
+ " history[-1][1] = \"\"\n",
264
+ " for character in bot_message:\n",
265
+ " history[-1][1] += character\n",
266
+ " time.sleep(0.05)\n",
267
+ " yield history\n",
268
+ "\n",
269
+ " msg.submit(user, [msg, chatbot], [msg, chatbot], queue=True).then(\n",
270
+ " bot, chatbot, chatbot\n",
271
+ " )\n",
272
+ " clear.click(lambda: None, None, chatbot, queue=True)\n",
273
+ "\n",
274
+ "demo.queue(concurrency_count=2)\n",
275
+ "demo.launch()"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": null,
281
+ "id": "fef38bdc",
282
+ "metadata": {
283
+ "scrolled": true
284
+ },
285
+ "outputs": [],
286
+ "source": [
287
+ "import gradio as gr\n",
288
+ "import random\n",
289
+ "import time\n",
290
+ "\n",
291
+ "with gr.Blocks() as demo:\n",
292
+ " chatbot = gr.Chatbot()\n",
293
+ " msg = gr.Textbox(placeholder='Got any spare time...let\\'s chat!!!')\n",
294
+ " gr.Textbox.style(msg,show_copy_button=True)\n",
295
+ " clear = gr.Button(\"Clear\")\n",
296
+ "\n",
297
+ " def respond(message, chat_history):\n",
298
+ " bot_message = generate_question(message)\n",
299
+ " bot_message = \"**\"+bot_message+\"**\"\n",
300
+ " chat_history.append((message, bot_message))\n",
301
+ " time.sleep(1)\n",
302
+ " return \"\", chat_history\n",
303
+ "\n",
304
+ " msg.submit(respond, [msg, chatbot], [msg, chatbot])\n",
305
+ " clear.click(lambda: None, None, chatbot, queue=False)\n",
306
+ "\n",
307
+ "demo.launch()"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": null,
313
+ "id": "a86d446a",
314
+ "metadata": {},
315
+ "outputs": [],
316
+ "source": []
317
+ }
318
+ ],
319
+ "metadata": {
320
+ "kernelspec": {
321
+ "display_name": "Python 3 (ipykernel)",
322
+ "language": "python",
323
+ "name": "python3"
324
+ },
325
+ "language_info": {
326
+ "codemirror_mode": {
327
+ "name": "ipython",
328
+ "version": 3
329
+ },
330
+ "file_extension": ".py",
331
+ "mimetype": "text/x-python",
332
+ "name": "python",
333
+ "nbconvert_exporter": "python",
334
+ "pygments_lexer": "ipython3",
335
+ "version": "3.10.1"
336
+ }
337
+ },
338
+ "nbformat": 4,
339
+ "nbformat_minor": 5
340
+ }
Untitled.ipynb ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "e0102cb4",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "Global seed set to 100\n"
14
+ ]
15
+ },
16
+ {
17
+ "data": {
18
+ "text/plain": [
19
+ "100"
20
+ ]
21
+ },
22
+ "execution_count": 1,
23
+ "metadata": {},
24
+ "output_type": "execute_result"
25
+ }
26
+ ],
27
+ "source": [
28
+ "from transformers import T5Tokenizer, T5ForConditionalGeneration \n",
29
+ "\n",
30
+ "from transformers import AdamW\n",
31
+ "import pandas as pd\n",
32
+ "import torch\n",
33
+ "import pytorch_lightning as pl\n",
34
+ "from pytorch_lightning.callbacks import ModelCheckpoint\n",
35
+ "from torch.nn.utils.rnn import pad_sequence\n",
36
+ "# from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler\n",
37
+ "\n",
38
+ "pl.seed_everything(100)"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": 2,
44
+ "id": "1ec5ec2a",
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": [
48
+ "MODEL_NAME='t5-base'"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 3,
54
+ "id": "8044c622",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
59
+ "INPUT_MAX_LEN = 128 \n",
60
+ "OUTPUT_MAX_LEN = 128"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": 4,
66
+ "id": "6390f2de",
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=512)"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": 5,
76
+ "id": "8eec35d1",
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "class T5Model(pl.LightningModule):\n",
81
+ " \n",
82
+ " def __init__(self):\n",
83
+ " super().__init__()\n",
84
+ " self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict = True)\n",
85
+ "\n",
86
+ " \n",
87
+ " def forward(self, input_ids, attention_mask, labels=None):\n",
88
+ " \n",
89
+ " output = self.model(\n",
90
+ " input_ids=input_ids, \n",
91
+ " attention_mask=attention_mask, \n",
92
+ " labels=labels\n",
93
+ " )\n",
94
+ " return output.loss, output.logits\n",
95
+ " \n",
96
+ " def training_step(self, batch, batch_idx):\n",
97
+ "\n",
98
+ " input_ids = batch[\"input_ids\"]\n",
99
+ " attention_mask = batch[\"attention_mask\"]\n",
100
+ " labels= batch[\"target\"]\n",
101
+ " loss, logits = self(input_ids , attention_mask, labels)\n",
102
+ "\n",
103
+ " \n",
104
+ " self.log(\"train_loss\", loss, prog_bar=True, logger=True)\n",
105
+ "\n",
106
+ " return {'loss': loss}\n",
107
+ " \n",
108
+ " def validation_step(self, batch, batch_idx):\n",
109
+ " input_ids = batch[\"input_ids\"]\n",
110
+ " attention_mask = batch[\"attention_mask\"]\n",
111
+ " labels= batch[\"target\"]\n",
112
+ " loss, logits = self(input_ids, attention_mask, labels)\n",
113
+ "\n",
114
+ " self.log(\"val_loss\", loss, prog_bar=True, logger=True)\n",
115
+ " \n",
116
+ " return {'val_loss': loss}\n",
117
+ "\n",
118
+ " def configure_optimizers(self):\n",
119
+ " return AdamW(self.parameters(), lr=0.0001)"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": 6,
125
+ "id": "e9d96844",
126
+ "metadata": {},
127
+ "outputs": [
128
+ {
129
+ "name": "stderr",
130
+ "output_type": "stream",
131
+ "text": [
132
+ "Lightning automatically upgraded your loaded checkpoint from v1.9.3 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file F:\\Projects & Open_source\\Chatbot_T5_kaggle\\best-model.ckpt`\n"
133
+ ]
134
+ }
135
+ ],
136
+ "source": [
137
+ "train_model = T5Model.load_from_checkpoint('best-model.ckpt',map_location=DEVICE)"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": 7,
143
+ "id": "3449943f",
144
+ "metadata": {},
145
+ "outputs": [],
146
+ "source": [
147
+ "train_model.freeze()"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": 8,
153
+ "id": "0e9f1058",
154
+ "metadata": {},
155
+ "outputs": [],
156
+ "source": [
157
+ "def generate_question(question):\n",
158
+ "\n",
159
+ " inputs_encoding = tokenizer(\n",
160
+ " question,\n",
161
+ " add_special_tokens=True,\n",
162
+ " max_length= INPUT_MAX_LEN,\n",
163
+ " padding = 'max_length',\n",
164
+ " truncation='only_first',\n",
165
+ " return_attention_mask=True,\n",
166
+ " return_tensors=\"pt\"\n",
167
+ " )\n",
168
+ "\n",
169
+ " \n",
170
+ " generate_ids = train_model.model.generate(\n",
171
+ " input_ids = inputs_encoding[\"input_ids\"],\n",
172
+ " attention_mask = inputs_encoding[\"attention_mask\"],\n",
173
+ " max_length = INPUT_MAX_LEN,\n",
174
+ " num_beams = 4,\n",
175
+ " num_return_sequences = 1,\n",
176
+ " no_repeat_ngram_size=2,\n",
177
+ " early_stopping=True,\n",
178
+ " )\n",
179
+ "\n",
180
+ " preds = [\n",
181
+ " tokenizer.decode(gen_id,\n",
182
+ " skip_special_tokens=True, \n",
183
+ " clean_up_tokenization_spaces=True)\n",
184
+ " for gen_id in generate_ids\n",
185
+ " ]\n",
186
+ "\n",
187
+ " return \"\".join(preds)\n"
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "code",
192
+ "execution_count": 9,
193
+ "id": "ee38a88c",
194
+ "metadata": {},
195
+ "outputs": [
196
+ {
197
+ "name": "stdout",
198
+ "output_type": "stream",
199
+ "text": [
200
+ "Ques: hi, how are you doing?\n",
201
+ "BOT: i'm so glad you're doing well.\n"
202
+ ]
203
+ }
204
+ ],
205
+ "source": [
206
+ "ques = \"hi, how are you doing?\"\n",
207
+ "print(\"Ques: \",ques)\n",
208
+ "print(\"BOT: \",generate_question(ques))"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": 11,
214
+ "id": "22aa4414",
215
+ "metadata": {},
216
+ "outputs": [
217
+ {
218
+ "name": "stdout",
219
+ "output_type": "stream",
220
+ "text": [
221
+ "Running on local URL: http://127.0.0.1:7861\n",
222
+ "\n",
223
+ "To create a public link, set `share=True` in `launch()`.\n"
224
+ ]
225
+ },
226
+ {
227
+ "data": {
228
+ "text/html": [
229
+ "<div><iframe src=\"http://127.0.0.1:7861/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
230
+ ],
231
+ "text/plain": [
232
+ "<IPython.core.display.HTML object>"
233
+ ]
234
+ },
235
+ "metadata": {},
236
+ "output_type": "display_data"
237
+ },
238
+ {
239
+ "data": {
240
+ "text/plain": []
241
+ },
242
+ "execution_count": 11,
243
+ "metadata": {},
244
+ "output_type": "execute_result"
245
+ }
246
+ ],
247
+ "source": [
248
+ "import gradio as gr\n",
249
+ "import random\n",
250
+ "import time\n",
251
+ "\n",
252
+ "with gr.Blocks() as demo:\n",
253
+ " chatbot = gr.Chatbot()\n",
254
+ " gr.Chatbot.style(chatbot,height=400)\n",
255
+ " msg = gr.Textbox(info=\"Press \\'Enter\\' to send\")\n",
256
+ " clear = gr.Button(\"Clear\")\n",
257
+ "\n",
258
+ " def user(user_message, history):\n",
259
+ " return \"\", history + [[user_message, None]]\n",
260
+ "\n",
261
+ " def bot(history):\n",
262
+ " bot_message = generate_question(history[-1][0])\n",
263
+ " history[-1][1] = \"\"\n",
264
+ " for character in bot_message:\n",
265
+ " history[-1][1] += character\n",
266
+ " time.sleep(0.05)\n",
267
+ " yield history\n",
268
+ "\n",
269
+ " msg.submit(user, [msg, chatbot], [msg, chatbot], queue=True).then(\n",
270
+ " bot, chatbot, chatbot\n",
271
+ " )\n",
272
+ " clear.click(lambda: None, None, chatbot, queue=True)\n",
273
+ "\n",
274
+ "demo.queue(concurrency_count=2)\n",
275
+ "demo.launch()"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": null,
281
+ "id": "fef38bdc",
282
+ "metadata": {
283
+ "scrolled": true
284
+ },
285
+ "outputs": [],
286
+ "source": [
287
+ "import gradio as gr\n",
288
+ "import random\n",
289
+ "import time\n",
290
+ "\n",
291
+ "with gr.Blocks() as demo:\n",
292
+ " chatbot = gr.Chatbot()\n",
293
+ " msg = gr.Textbox(placeholder='Got any spare time...let\\'s chat!!!')\n",
294
+ " gr.Textbox.style(msg,show_copy_button=True)\n",
295
+ " clear = gr.Button(\"Clear\")\n",
296
+ "\n",
297
+ " def respond(message, chat_history):\n",
298
+ " bot_message = generate_question(message)\n",
299
+ " bot_message = \"**\"+bot_message+\"**\"\n",
300
+ " chat_history.append((message, bot_message))\n",
301
+ " time.sleep(1)\n",
302
+ " return \"\", chat_history\n",
303
+ "\n",
304
+ " msg.submit(respond, [msg, chatbot], [msg, chatbot])\n",
305
+ " clear.click(lambda: None, None, chatbot, queue=False)\n",
306
+ "\n",
307
+ "demo.launch()"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": null,
313
+ "id": "a86d446a",
314
+ "metadata": {},
315
+ "outputs": [],
316
+ "source": []
317
+ }
318
+ ],
319
+ "metadata": {
320
+ "kernelspec": {
321
+ "display_name": "Python 3 (ipykernel)",
322
+ "language": "python",
323
+ "name": "python3"
324
+ },
325
+ "language_info": {
326
+ "codemirror_mode": {
327
+ "name": "ipython",
328
+ "version": 3
329
+ },
330
+ "file_extension": ".py",
331
+ "mimetype": "text/x-python",
332
+ "name": "python",
333
+ "nbconvert_exporter": "python",
334
+ "pygments_lexer": "ipython3",
335
+ "version": "3.10.1"
336
+ }
337
+ },
338
+ "nbformat": 4,
339
+ "nbformat_minor": 5
340
+ }
__pycache__/gradio.cpython-310.pyc ADDED
Binary file (790 Bytes). View file
 
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
2
+
3
+ from transformers import AdamW
4
+ import pandas as pd
5
+ import torch
6
+ import pytorch_lightning as pl
7
+ from pytorch_lightning.callbacks import ModelCheckpoint
8
+ from torch.nn.utils.rnn import pad_sequence
9
+ # from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler
10
+
11
+ pl.seed_everything(100)
12
+
13
+ MODEL_NAME='t5-base'
14
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+ INPUT_MAX_LEN = 128
16
+ OUTPUT_MAX_LEN = 128
17
+
18
+ tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=512)
19
+
20
+ class T5Model(pl.LightningModule):
21
+
22
+ def __init__(self):
23
+ super().__init__()
24
+ self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict = True)
25
+
26
+
27
+ def forward(self, input_ids, attention_mask, labels=None):
28
+
29
+ output = self.model(
30
+ input_ids=input_ids,
31
+ attention_mask=attention_mask,
32
+ labels=labels
33
+ )
34
+ return output.loss, output.logits
35
+
36
+ def training_step(self, batch, batch_idx):
37
+
38
+ input_ids = batch["input_ids"]
39
+ attention_mask = batch["attention_mask"]
40
+ labels= batch["target"]
41
+ loss, logits = self(input_ids , attention_mask, labels)
42
+
43
+
44
+ self.log("train_loss", loss, prog_bar=True, logger=True)
45
+
46
+ return {'loss': loss}
47
+
48
+ def validation_step(self, batch, batch_idx):
49
+ input_ids = batch["input_ids"]
50
+ attention_mask = batch["attention_mask"]
51
+ labels= batch["target"]
52
+ loss, logits = self(input_ids, attention_mask, labels)
53
+
54
+ self.log("val_loss", loss, prog_bar=True, logger=True)
55
+
56
+ return {'val_loss': loss}
57
+
58
+ def configure_optimizers(self):
59
+ return AdamW(self.parameters(), lr=0.0001)
60
+
61
+ train_model = T5Model.load_from_checkpoint('best-model.ckpt',map_location=DEVICE)
62
+ train_model.freeze()
63
+
64
+ def generate_question(question):
65
+
66
+ inputs_encoding = tokenizer(
67
+ question,
68
+ add_special_tokens=True,
69
+ max_length= INPUT_MAX_LEN,
70
+ padding = 'max_length',
71
+ truncation='only_first',
72
+ return_attention_mask=True,
73
+ return_tensors="pt"
74
+ )
75
+
76
+
77
+ generate_ids = train_model.model.generate(
78
+ input_ids = inputs_encoding["input_ids"],
79
+ attention_mask = inputs_encoding["attention_mask"],
80
+ max_length = INPUT_MAX_LEN,
81
+ num_beams = 4,
82
+ num_return_sequences = 1,
83
+ no_repeat_ngram_size=2,
84
+ early_stopping=True,
85
+ )
86
+
87
+ preds = [
88
+ tokenizer.decode(gen_id,
89
+ skip_special_tokens=True,
90
+ clean_up_tokenization_spaces=True)
91
+ for gen_id in generate_ids
92
+ ]
93
+
94
+ return "".join(preds)
95
+
96
+ import gradio as gr
97
+ import random
98
+ import time
99
+
100
+ with gr.Blocks() as demo:
101
+ chatbot = gr.Chatbot()
102
+ gr.Chatbot.style(chatbot,height=400)
103
+ msg = gr.Textbox(info="Press \'Enter\' to send")
104
+ clear = gr.Button("Clear")
105
+
106
+ def user(user_message, history):
107
+ return "", history + [[user_message, None]]
108
+
109
+ def bot(history):
110
+ bot_message = generate_question(history[-1][0])
111
+ history[-1][1] = ""
112
+ for character in bot_message:
113
+ history[-1][1] += character
114
+ time.sleep(0.05)
115
+ yield history
116
+
117
+ msg.submit(user, [msg, chatbot], [msg, chatbot], queue=True).then(
118
+ bot, chatbot, chatbot
119
+ )
120
+ clear.click(lambda: None, None, chatbot, queue=True)
121
+
122
+ demo.queue(concurrency_count=2)
123
+ demo.launch()
best-model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d21c48743863e3b7355f0a432cf82b794091fa3ff2ad94c630bb3e9e2975b13
3
+ size 2675123255
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers==4.27.4
2
+ pandas==1.5.3
3
+ torch==2.0.0
4
+ pytorch-lightning==2.0.2
5
+ gradio==3.24.1