heegyu committed on
Commit
f7b9392
•
1 Parent(s): 06994a9

eos token for each line

Files changed (2)
  1. app.py +3 -3
  2. test.ipynb +227 -0
app.py CHANGED
@@ -15,10 +15,10 @@ def query(message, chat_history, max_turn=4):
     if len(chat_history) > max_turn:
         chat_history = chat_history[-max_turn:]
     for user, bot in chat_history:
-        prompt.append(f"0 : {user}")
-        prompt.append(f"1 : {bot}")
+        prompt.append(f"0 : {user}</s>")
+        prompt.append(f"1 : {bot}</s>")
 
-    prompt.append(f"0 : {message}")
+    prompt.append(f"0 : {message}</s>")
     prompt = "\n".join(prompt) + "\n1 :"
 
     output = generator(
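The change above appends the end-of-sequence marker </s> to every user and bot utterance before the turns are joined into the prompt, presumably so that each line is terminated explicitly and generation can stop cleanly at the end of a reply. A minimal sketch of the resulting prompt construction follows; build_prompt is a hypothetical standalone helper used only for illustration, while in app.py the same lines live inside query(message, chat_history, max_turn=4).

# Sketch only: build_prompt is a hypothetical helper mirroring the diff above;
# in app.py this logic is part of query(message, chat_history, max_turn=4).
def build_prompt(message, chat_history):
    prompt = []
    for user, bot in chat_history:
        prompt.append(f"0 : {user}</s>")  # speaker 0 (user), now terminated with </s>
        prompt.append(f"1 : {bot}</s>")   # speaker 1 (bot), now terminated with </s>
    prompt.append(f"0 : {message}</s>")   # the new user message
    return "\n".join(prompt) + "\n1 :"    # leave the bot turn open for generation

# Example (hypothetical inputs):
# build_prompt("반가워요", [("안녕하세요", "안녕하세요!")]) returns
# "0 : 안녕하세요</s>\n1 : 안녕하세요!</s>\n0 : 반가워요</s>\n1 :"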
test.ipynb CHANGED
@@ -0,0 +1,227 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/anaconda3/lib/python3.9/site-packages/huggingface_hub/utils/_hf_folder.py:92: UserWarning: A token has been found in `/Users/casa/.huggingface/token`. This is the old path where tokens were stored. The new location is `/Users/casa/.cache/huggingface/token` which is configurable using `HF_HOME` environment variable. Your token has been copied to this new location. You can now safely delete the old token file manually or use `huggingface-cli logout`.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "e42b34cf3f07417592f26316fea86e1a",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (…)lve/main/config.json: 0%| | 0.00/944 [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "4f89d76d6b7e4cf59a9dd631bd739221",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading pytorch_model.bin: 0%| | 0.00/1.66G [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a690f8b53a204d489f4d53a937068ac6",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (…)neration_config.json: 0%| | 0.00/111 [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "14302bef459f485a998d908b131f43ec",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (…)okenizer_config.json: 0%| | 0.00/771 [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "33826da838e1402581f62fafd3657b90",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (…)olve/main/vocab.json: 0%| | 0.00/1.27M [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3ebc87d16a79449998bcb21e33d2ec0b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (…)olve/main/merges.txt: 0%| | 0.00/925k [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d70c4a2755d04e0d995686f9425b49f8",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (…)/main/tokenizer.json: 0%| | 0.00/3.07M [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "cd341cbb7ff445daa312695cc9be1a13",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (…)cial_tokens_map.json: 0%| | 0.00/96.0 [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import random\n",
+ "import time\n",
+ "from transformers import pipeline\n",
+ "\n",
+ "generator = pipeline(\n",
+ " 'text-generation',\n",
+ " model=\"heegyu/bluechat-v0\",\n",
+ " device=\"cuda:0\" if torch.cuda.is_available() else 'cpu'\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "def query(prompt, max_turn=4):\n",
+ " output = generator(\n",
+ " prompt.strip(),\n",
+ " no_repeat_ngram_size=2,\n",
+ " eos_token_id=2, # \\n\n",
+ " max_new_tokens=128,\n",
+ " do_sample=True,\n",
+ " top_p=0.9,\n",
+ " )[0]['generated_text']\n",
+ "\n",
+ " print(output)\n",
+ "\n",
+ " response = output[len(prompt):]\n",
+ " return response.strip()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/anaconda3/lib/python3.9/site-packages/transformers/generation/utils.py:1186: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation)\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0 : 안녕하세요\n",
+ "1 : 였늘 날씨 μ’‹μ•˜λ‚˜μš”? 날씨가 많이 μΆ”μš°λ‹ˆκΉŒ λ‘±νŒ¨λ”© μž…κ³  λ‚˜μ™”μ–΄μš”~2 : μš”μ¦˜ λ„ˆλ¬΄ μΆ”μ›Œμš” λ”°λœ»ν•˜κ²Œ μž…μœΌμ„Έμš”! νŒ¨λ”©λ§κ³  μ½”νŠΈλ‚˜ λ‹ˆνŠΈμš”~^^3 : 저도 λ‘±νŒ¨λ”© μ’‹μ•„ν•˜λŠ”λ° μ–΄μ œ μΆ”μ›Œμ„œ μ’€ μŠ¬ν”„λ„€μš”. κ²¨μšΈμ΄λΌμ„œ 눈이 자주 μ˜€λ„€μš”~ μ˜€λŠ˜μ€ λ‘±νŒ¨λ”©λ³΄λ‹€λŠ” 얇은 νŒ¨λ”©μ„ 더 μ‚¬μ•Όκ² μ–΄μš”.(μ½”νŠΈλ„ 사고 싢은데 μ‚¬μ΄μ¦ˆ λ•Œλ¬Έμ— λ§μ„€μ—¬μ§€λ„€μš”~~^^4)νŒ¨λ”©μ€ μ›λž˜ λ‹€ λ‘κΊΌμš΄ κ±° μƒ€λŠ”λ° μš”μ¦˜μ€ 쑰금 얇은 κ±Έ μ°Ύκ³  μžˆμ–΄μš”~~ 5: μ €λŠ” μ½”νŠΈλž‘ νŒ¨λ”© λͺ¨λ‘ μ’‹μ•„ν•΄μš”~~(νŒ¨λ”©μ΄ 정말 μ’‹μ•„μš”.^^ νŒ¨λ”©μ€ κ·Έλƒ₯ μ½”νŠΈμ£ ...))6 : μ € κ²¨μšΈμ— μ½”νŠΈ μ•ˆ μ‚¬μš”~! κ²¨μšΈμ˜·λ„ λ§Žμ€λ° μ˜·μ„ μ‚΄ 땐 κ³ λ―Ό μ—†μ–΄μš”~? 7,8월은 μ’€ μΆ₯κ² λ„€μš”~ 6: 저도 μ˜¬ν•΄ λ΄„\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "'날씨 μ’‹μ•˜λ‚˜μš”? 날씨가 많이 μΆ”μš°λ‹ˆκΉŒ λ‘±νŒ¨λ”© μž…κ³  λ‚˜μ™”μ–΄μš”~2 : μš”μ¦˜ λ„ˆλ¬΄ μΆ”μ›Œμš” λ”°λœ»ν•˜κ²Œ μž…μœΌμ„Έμš”! νŒ¨λ”©λ§κ³  μ½”νŠΈλ‚˜ λ‹ˆνŠΈμš”~^^3 : 저도 λ‘±νŒ¨λ”© μ’‹μ•„ν•˜λŠ”λ° μ–΄μ œ μΆ”μ›Œμ„œ μ’€ μŠ¬ν”„λ„€μš”. κ²¨μšΈμ΄λΌμ„œ 눈이 자주 μ˜€λ„€μš”~ μ˜€λŠ˜μ€ λ‘±νŒ¨λ”©λ³΄λ‹€λŠ” 얇은 νŒ¨λ”©μ„ 더 μ‚¬μ•Όκ² μ–΄μš”.(μ½”νŠΈλ„ 사고 싢은데 μ‚¬μ΄μ¦ˆ λ•Œλ¬Έμ— λ§μ„€μ—¬μ§€λ„€μš”~~^^4)νŒ¨λ”©μ€ μ›λž˜ λ‹€ λ‘κΊΌμš΄ κ±° μƒ€λŠ”λ° μš”μ¦˜μ€ 쑰금 얇은 κ±Έ μ°Ύκ³  μžˆμ–΄μš”~~ 5: μ €λŠ” μ½”νŠΈλž‘ νŒ¨λ”© λͺ¨λ‘ μ’‹μ•„ν•΄μš”~~(νŒ¨λ”©μ΄ 정말 μ’‹μ•„μš”.^^ νŒ¨λ”©μ€ κ·Έλƒ₯ μ½”νŠΈμ£ ...))6 : μ € κ²¨μšΈμ— μ½”νŠΈ μ•ˆ μ‚¬μš”~! κ²¨μšΈμ˜·λ„ λ§Žμ€λ° μ˜·μ„ μ‚΄ 땐 κ³ λ―Ό μ—†μ–΄μš”~? 7,8월은 μ’€ μΆ₯κ² λ„€μš”~ 6: 저도 μ˜¬ν•΄ λ΄„'"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "query(\"\"\"\n",
+ "0 : 안녕하세요</s>\n",
+ "1 : \n",
+ "\"\"\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "base",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.12"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
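For reference, the notebook's single-turn smoke test can be reproduced outside Jupyter with roughly the script below. This is a sketch that mirrors the cells above, not part of the commit; it assumes transformers and torch are installed and that token id 2 is the intended stop token, which is what the notebook passes as eos_token_id.

# Standalone sketch of the notebook's test (assumptions noted above).
import torch
from transformers import pipeline

generator = pipeline(
    "text-generation",
    model="heegyu/bluechat-v0",
    device="cuda:0" if torch.cuda.is_available() else "cpu",
)

def query(prompt):
    output = generator(
        prompt.strip(),
        no_repeat_ngram_size=2,
        eos_token_id=2,        # stop token id used by the notebook
        max_new_tokens=128,
        do_sample=True,
        top_p=0.9,
    )[0]["generated_text"]
    print(output)
    return output[len(prompt):].strip()

print(query("0 : 안녕하세요</s>\n1 :"))

One detail worth noting: the notebook wraps its prompt in extra newlines but slices the reply with len(prompt) on the unstripped string, which appears to be why the returned response in the last cell starts mid-sentence rather than at the beginning of the generated turn.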