1q2w3e4r5t
committed on
Commit 4bc1355
Parent(s): bf62210
Update app.py
app.py
CHANGED
@@ -8,7 +8,7 @@ from lit_llama import LLaMA, Tokenizer
 from lit_llama.utils import EmptyInitOnDevice
 
 
-class ChatDoctor:
+class ChatBot:
     def __init__(self, model, tokenizer, fabric):
         self.model = model
         self.tokenizer = tokenizer
@@ -26,17 +26,18 @@ class ChatDoctor:
             "환자의 문의 내용에 대해 답변하세요. 환자의 질병을 진단하고, 가능하면 처방을 하세요. \n\n"
             f"### 문의:\n{example['instruction']}\n\n### 응답:"
         )
-
-    #
+
+    # default generation
     @torch.no_grad()
     def generate(
-        self,
-        idx,
-        max_new_tokens,
+        self,
+        idx,
+        max_new_tokens,
         max_seq_length=None,
-        temperature=0.8,
-        top_k=None,
-        eos_id=None
+        temperature=0.8,
+        top_k=None,
+        eos_id=None,
+        repetition_penalty=1.1,
     ):
         T = idx.size(0)
         T_new = T + max_new_tokens
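Note: the new repetition_penalty=1.1 parameter is added to the signature here, but none of the visible hunks show where it is applied. Below is a minimal sketch of the usual CTRL-style application inside a sampling loop, assuming a 1-D logits tensor over the vocabulary and a tensor of the tokens generated so far; the helper name and shapes are illustrative, not code from this repo.

import torch

def apply_repetition_penalty(logits: torch.Tensor, generated: torch.Tensor, penalty: float = 1.1) -> torch.Tensor:
    # Look up the logits of tokens that already appear in the output so far.
    scores = logits[generated]
    # Discourage re-selection: shrink positive logits, push negative logits lower.
    scores = torch.where(scores > 0, scores / penalty, scores * penalty)
    logits[generated] = scores
    return logits

With penalty > 1.0 this monotonically lowers the score of every previously emitted token, which is why 1.1 acts as a mild anti-repetition nudge rather than a hard ban.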
@@ -86,14 +87,9 @@ class ChatDoctor:
 
         return idx
 
-
-
-
-        # The user's message is added to the history with None as the bot's response.
-        return "", history + [[user_message, None]]
-
-    # This method generates and handles bot's responses.
-    def bot(self, history, max_new_tokens, top_k, temperature):
+    # LLM generation method
+    def ans(self, user_message, history, max_new_tokens, top_k, temperature):
+        history = history + [[user_message, None]]
         instruction = history[-1][0].strip()
         sample = { "instruction" : instruction, "input" : None }
         prompt = self.generate_prompt(sample)
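Note: the new ans() signature lines up with what gr.ChatInterface expects from its fn — it is called as fn(message, history, *additional_inputs) and should return the reply string. A standalone sketch of that contract (the echo function and slider values are illustrative only, mirroring the sliders added later in this commit):

import gradio as gr

def echo(message, history, max_new_tokens, top_k, temperature):
    # ChatInterface passes each additional_inputs component's value positionally.
    return f"(echo) {message} [max_new_tokens={max_new_tokens}, top_k={top_k}, temperature={temperature}]"

demo = gr.ChatInterface(
    fn=echo,
    additional_inputs=[
        gr.Slider(1, 512, value=512, step=1, label="max_new_tokens"),
        gr.Slider(1, 300, value=150, step=1, label="top_k"),
        gr.Slider(0.1, 1.0, value=0.5, step=0.1, label="temperature"),
    ],
)

if __name__ == "__main__":
    demo.launch()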
@@ -106,17 +102,15 @@ class ChatDoctor:
             top_k=top_k,
             eos_id=self.tokenizer.eos_id
         )
-
+
         self.model.reset_cache()
 
         response = self.tokenizer.decode(y)
         response = response.split('응답:')[1].strip()
-
-        # The history is updated with the bot's response.
-        history[-1][1] = response
-
-        return history
 
+        # update history
+        history[-1][1] = response
+        return response
 
 def load_model():
     # Settings for inference
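Note: response.split('응답:')[1] raises IndexError whenever the decoded text lacks the '응답:' marker (e.g. if generation stops before the response section). A more defensive variant, sketched here as a hypothetical helper rather than code from this commit:

def extract_response(decoded: str, marker: str = "응답:") -> str:
    # partition() never raises: if the marker is missing, fall back to the full text.
    _, sep, tail = decoded.partition(marker)
    return tail.strip() if sep else decoded.strip()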
@@ -143,82 +137,66 @@ def load_model():
 
     return model, tokenizer, fabric
 
-
-def setup_gradio_ui(
-
-
-
-
-
-
-
-
-
-
-
-
-
-        ## Parameters
-        """)
-
-    max_new_tokens = gr.Slider(
-        minimum=1,
-        maximum=512,
-        step=1,
-        value=512,
-        label="max_new_tokens",
-        info="The number of new tokens to generate",
-        interactive=True
-    )
-
-    top_k = gr.Slider(
-        minimum=1,
-        maximum=
-        step=1,
-        value=
-        label="
-        info="
-        interactive=True
-    )
-
-    temperature = gr.Slider(
-        minimum=0.1,
-        maximum=1.0,
-        step=0.1,
-        value=0.8,
-        label="temperature",
-        info="Scales the predicted logits by 1 / temperature",
-        interactive=True
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+# theme 'Taithrah/Minimal' 'abidlabs/dracula_test' 'JohnSmith9982/small_and_pretty'
+def setup_gradio_ui(chat_bot, css):
+    gr.ChatInterface(
+        fn=chat_bot.ans,
+        css=css,
+        textbox=gr.Textbox(placeholder="질문을 입력해주세요.", container=False, scale=7),
+        chatbot=gr.Chatbot(height=600, value=[[None, "안녕하세요. 무엇이 궁금하신가요?"]], avatar_images=["asset/human.png", "asset/bot.jpg"]),
+        title="의료용 챗봇 데모",
+        theme='soft',
+        examples=[["두통이 너무 심해요."], ["배가 아프고 토할 것 같아요."], ["허리가 끊어질 듯이 아파요."]],
+        submit_btn=gr.Button(value="전송", icon="send.png", elem_id="green"),
+        retry_btn=gr.Button(value="다시보내기 (재질문)↩", elem_id="blue"),
+        undo_btn=gr.Button(value="이전 챗 삭제 ❌", elem_id="blue"),
+        clear_btn=gr.Button(value="전 챗 삭제 🚫", elem_id="blue"),
+        additional_inputs=[
+            gr.Slider(
+                minimum=1,
+                maximum=512,
+                step=1,
+                value=512,
+                label="max_new_tokens",
+                info="최대 생성 가능 토큰 수",
+                interactive=True
+            ),
+
+            gr.Slider(
+                minimum=1,
+                maximum=300,
+                step=1,
+                value=150,
+                label="top_k",
+                info="확률이 가장 높은 토큰 k개 샘플링",
+                interactive=True
+            ),
+
+            gr.Slider(
+                minimum=0.1,
+                maximum=1.0,
+                step=0.1,
+                value=0.5,
+                label="temperature",
+                info="1에 가까울수록 다양한 답변 생성",
+                interactive=True
+            )
+        ]
+    ).queue().launch()
 
 def main():
-    #
+    # load model and tokenizer
     model, tokenizer, fabric = load_model()
 
-    #
-
+    # create the chatbot instance
+    chat_bot = ChatBot(model, tokenizer, fabric)
+
+    # ui
+    css = """
+    #green {background-color: #00EF91}
+    #blue {background-color: #B9E2FA}
+    """
+    setup_gradio_ui(chat_bot, css)
 
-    # Gradio UI setup and launch
-    setup_gradio_ui(chat_doctor)
-
 if __name__ == "__main__":
-    main()
+    main()
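Note: the css string built in main() targets the elem_id values ("green", "blue") given to the buttons above; elem_id becomes the component's HTML id, which the #green / #blue selectors match. A self-contained sketch of that pairing, reusing the commit's colors but with otherwise illustrative components:

import gradio as gr

css = """
#green {background-color: #00EF91}
#blue {background-color: #B9E2FA}
"""

with gr.Blocks(css=css) as demo:
    # elem_id is emitted as the HTML id attribute, so the CSS id rules apply.
    gr.Button("submit", elem_id="green")
    gr.Button("retry", elem_id="blue")

if __name__ == "__main__":
    demo.launch()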