Commit 4bc1355 (1 parent: bf62210), committed by 1q2w3e4r5t

Update app.py

Files changed (1):
  1. app.py +72 -94
app.py CHANGED
@@ -8,7 +8,7 @@ from lit_llama import LLaMA, Tokenizer
 from lit_llama.utils import EmptyInitOnDevice
 
 
-class ChatDoctor:
+class ChatBot:
     def __init__(self, model, tokenizer, fabric):
         self.model = model
         self.tokenizer = tokenizer
@@ -26,17 +26,18 @@ class ChatDoctor:
         "ν™˜μžμ˜ 문의 λ‚΄μš©μ— λŒ€ν•΄ λ‹΅λ³€ν•˜μ„Έμš”. ν™˜μžμ˜ μ§ˆλ³‘μ„ μ§„λ‹¨ν•˜κ³ , κ°€λŠ₯ν•˜λ©΄ μ²˜λ°©μ„ ν•˜μ„Έμš”. \n\n"
         f"### 문의:\n{example['instruction']}\n\n### 응닡:"
     )
 
-    # This method generates the chatbot's responses.
+    # default generation
     @torch.no_grad()
     def generate(
         self,
         idx,
         max_new_tokens,
         max_seq_length=None,
         temperature=0.8,
         top_k=None,
-        eos_id=None
+        eos_id=None,
+        repetition_penalty=1.1,
     ):
         T = idx.size(0)
         T_new = T + max_new_tokens
@@ -86,14 +87,9 @@ class ChatDoctor:
 
         return idx
 
-
-    # This method handles user's messages and updates the conversation history.
-    def user(self, user_message, history):
-        # The user's message is added to the history with None as the bot's response.
-        return "", history + [[user_message, None]]
-
-    # This method generates and handles bot's responses.
-    def bot(self, history, max_new_tokens, top_k, temperature):
+    # LLM generation function
+    def ans(self, user_message, history, max_new_tokens, top_k, temperature):
+        history = history + [[user_message, None]]
         instruction = history[-1][0].strip()
         sample = { "instruction" : instruction, "input" : None }
         prompt = self.generate_prompt(sample)
@@ -106,17 +102,15 @@ class ChatDoctor:
             top_k=top_k,
             eos_id=self.tokenizer.eos_id
         )
 
         self.model.reset_cache()
 
         response = self.tokenizer.decode(y)
         response = response.split('응닡:')[1].strip()
 
-        # The history is updated with the bot's response.
-        history[-1][1] = response
-
-        return history
+        # update the history with the bot's response
+        history[-1][1] = response
+        return response
 
 
 def load_model():
     # Settings for inference
@@ -143,82 +137,66 @@ def load_model():
 
     return model, tokenizer, fabric
 
-
-def setup_gradio_ui(chat_doctor):
-    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
-        gr.Markdown(
-            """
-            # ChatDoctor-KR Demo
-
-            last modified : 23.05.18
-            """)
-
-        chatbot = gr.Chatbot(label="ChatDoctor-KR")
-        msg = gr.Textbox(lines=1, placeholder="질문 μž…λ ₯ ν›„ μ—”ν„°λ₯Ό λˆ„λ₯΄μ„Έμš”.", label="질문")
-        clear = gr.Button("클리어")
-
-        gr.Markdown(
-            """
-            ## Parameters
-            """)
-
-        max_new_tokens = gr.Slider(
-            minimum=1,
-            maximum=512,
-            step=1,
-            value=512,
-            label="max_new_tokens",
-            info="The number of new tokens to generate",
-            interactive=True
-        )
-
-        top_k = gr.Slider(
-            minimum=1,
-            maximum=300,
-            step=1,
-            value=200,
-            label="top_k",
-            info="If specified, only sample among the tokens with the k highest probabilities",
-            interactive=True
-        )
-
-        temperature = gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            step=0.1,
-            value=0.8,
-            label="temperature",
-            info="Scales the predicted logits by 1 / temperature",
-            interactive=True
-        )
-
-        with gr.Accordion(label="Open for More!", open=False):
-            gr.Markdown("Blah Blah ...")
-
-        submit_result = msg.submit(
-            chat_doctor.user, [msg, chatbot], [msg, chatbot], queue=False
-        )
-        submit_result.then(
-            chat_doctor.bot, [chatbot, max_new_tokens, top_k, temperature], chatbot
-        )
-
-        # This part clears the chatbot history when the clear button is clicked.
-        clear.click(lambda: None, None, chatbot, queue=False)
-
-        demo.queue()
-
-        demo.launch(share=True, server_name="0.0.0.0")
+# theme candidates: 'Taithrah/Minimal' 'abidlabs/dracula_test' 'JohnSmith9982/small_and_pretty'
+def setup_gradio_ui(chat_bot, css):
+    gr.ChatInterface(
+        fn=chat_bot.ans,
+        css=css,
+        textbox=gr.Textbox(placeholder="μ§ˆλ¬Έμ„ μž…λ ₯ν•΄μ£Όμ„Έμš”.", container=False, scale=7),
+        chatbot=gr.Chatbot(height=600, value=[[None, "μ•ˆλ…•ν•˜μ„Έμš”. 무엇이 κΆκΈˆν•˜μ‹ κ°€μš”?"]], avatar_images=["asset/human.png", "asset/bot.jpg"]),
+        title="의료용 챗봇 데λͺ¨",
+        theme='soft',
+        examples=[["두톡이 λ„ˆλ¬΄ μ‹¬ν•΄μš”."], ["λ°°κ°€ μ•„ν”„κ³  토할것 κ°™μ•„μš”."], ["ν—ˆλ¦¬κ°€ λŠμ–΄μ§ˆ 듯이 μ•„νŒŒμš”."]],
+        submit_btn=gr.Button(value="전솑", icon="send.png", elem_id="green"),
+        retry_btn=gr.Button(value="λ‹€μ‹œλ³΄λ‚΄κΈ° (재질문)↩", elem_id="blue"),
+        undo_btn=gr.Button(value="이전챗 μ‚­μ œ ❌", elem_id="blue"),
+        clear_btn=gr.Button(value="μ „μ±— μ‚­μ œ πŸ’«", elem_id="blue"),
+        additional_inputs=[
+            gr.Slider(
+                minimum=1,
+                maximum=512,
+                step=1,
+                value=512,
+                label="max_new_tokens",
+                info="μ΅œλŒ€ 생성 κ°€λŠ₯ 토큰 수",
+                interactive=True
+            ),
+
+            gr.Slider(
+                minimum=1,
+                maximum=300,
+                step=1,
+                value=150,
+                label="top_k",
+                info="ν™•λ₯ μ΄ κ°€μž₯ 높은 토큰 k개 μƒ˜ν”Œλ§",
+                interactive=True
+            ),
+
+            gr.Slider(
+                minimum=0.1,
+                maximum=1.0,
+                step=0.1,
+                value=0.5,
+                label="temperature",
+                info="1에 κ°€κΉŒμšΈμˆ˜λ‘ λ‹€μ–‘ν•œ λ‹΅λ³€ 생성",
+                interactive=True
+            )
+        ]
+    ).queue().launch()
 
 def main():
     # Load model and tokenizer
     model, tokenizer, fabric = load_model()
 
-    # ChatDoctor instance
-    chat_doctor = ChatDoctor(model, tokenizer, fabric)
+    # create the chatbot instance
+    chat_bot = ChatBot(model, tokenizer, fabric)
 
-    # Gradio UI setup and launch
-    setup_gradio_ui(chat_doctor)
+    # UI setup and launch
+    css = """
+    #green {background-color: #00EF91}
+    #blue {background-color: #B9E2FA}
+    """
+    setup_gradio_ui(chat_bot, css)
 
 if __name__ == "__main__":
     main()
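Note on the new generate() signature: repetition_penalty=1.1 is added as a parameter, but the sampling loop that consumes it lies outside the hunks shown above. For reference, a common (CTRL-style) way to apply such a penalty before temperature scaling and top-k sampling is sketched below; apply_repetition_penalty is an illustrative helper, not the committed implementation.

import torch

# Illustrative sketch only: penalize tokens that already occur in the
# generated sequence so they are less likely to be sampled again.
def apply_repetition_penalty(logits: torch.Tensor, generated: torch.Tensor,
                             repetition_penalty: float = 1.1) -> torch.Tensor:
    for token_id in set(generated.tolist()):
        score = logits[token_id].item()
        # Dividing positive scores (and multiplying negative ones) by the
        # penalty lowers the token's probability after softmax.
        logits[token_id] = score / repetition_penalty if score > 0 else score * repetition_penalty
    return logits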
 
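The separate user()/bot() callbacks disappear because gr.ChatInterface manages the conversation state itself: it invokes fn as fn(message, history, *additional_inputs) and expects the reply string back, which is why ans() now returns response rather than the updated history. A minimal, self-contained sketch of that contract, using a hypothetical echo handler in place of the model call (slider ranges mirror the diff):

import gradio as gr

def ans(message, history, max_new_tokens, top_k, temperature):
    # ChatInterface passes the user message and the chat history first,
    # then the additional_inputs values in their declared order.
    return f"echo: {message} (max_new_tokens={max_new_tokens}, top_k={top_k}, temperature={temperature})"

gr.ChatInterface(
    fn=ans,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=512, value=512, label="max_new_tokens"),
        gr.Slider(minimum=1, maximum=300, value=150, label="top_k"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="temperature"),
    ],
).queue().launch()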