Update app.py
Browse files
app.py
CHANGED
|
@@ -4,26 +4,30 @@ import torch
|
|
| 4 |
from threading import Thread
|
| 5 |
|
| 6 |
# --- 1. λͺ¨λΈ λ‘λ (Spaceμ GPU νμ©) ---
|
| 7 |
-
#
|
| 8 |
-
MODEL_NAME = "kakaocorp/kanana-1.5-2.1b-instruct-
|
| 9 |
|
| 10 |
print(f"λͺ¨λΈμ λ‘λ© μ€μ
λλ€: {MODEL_NAME} (Space GPU μ¬μ©)")
|
| 11 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 12 |
model = AutoModelForCausalLM.from_pretrained(
|
| 13 |
MODEL_NAME,
|
| 14 |
-
|
| 15 |
-
device_map="auto"
|
| 16 |
)
|
| 17 |
print("λͺ¨λΈ λ‘λ© μλ£!")
|
| 18 |
|
| 19 |
# --- 2. μ±λ΄ μλ΅ ν¨μ (Gradioκ° μ΄ ν¨μλ₯Ό νΈμΆ) ---
|
|
|
|
|
|
|
| 20 |
def predict(message, history):
|
| 21 |
|
| 22 |
# Kananaμ ν둬ννΈ νμ: <bos>user\n{prompt}\n<eos>assistant\n
|
| 23 |
history_prompt = ""
|
|
|
|
| 24 |
for user_msg, assistant_msg in history:
|
| 25 |
history_prompt += f"<bos>user\n{user_msg}\n<eos>assistant\n{assistant_msg}\n"
|
| 26 |
|
|
|
|
| 27 |
final_prompt = history_prompt + f"<bos>user\n{message}\n<eos>assistant\n"
|
| 28 |
|
| 29 |
inputs = tokenizer(final_prompt, return_tensors="pt").to(model.device)
|
|
@@ -31,35 +35,37 @@ def predict(message, history):
|
|
| 31 |
# --- μ€μκ° νμ΄ν ν¨κ³Ό(μ€νΈλ¦¬λ°)λ₯Ό μν μ€μ ---
|
| 32 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 33 |
|
|
|
|
| 34 |
generation_kwargs = dict(
|
| 35 |
-
inputs,
|
| 36 |
streamer=streamer,
|
| 37 |
-
max_new_tokens=1024,
|
| 38 |
eos_token_id=tokenizer.eos_token_id,
|
| 39 |
pad_token_id=tokenizer.pad_token_id,
|
| 40 |
-
temperature=0.7,
|
| 41 |
-
do_sample=True
|
| 42 |
)
|
| 43 |
|
| 44 |
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
| 45 |
thread.start()
|
| 46 |
|
|
|
|
| 47 |
generated_text = ""
|
| 48 |
for new_text in streamer:
|
| 49 |
generated_text += new_text
|
| 50 |
-
yield generated_text
|
| 51 |
|
| 52 |
# --- 3. Gradio μ±λ΄ UI μμ± ---
|
| 53 |
-
#
|
| 54 |
chatbot_ui = gr.ChatInterface(
|
| 55 |
fn=predict, # μ±λ΄μ΄ μ¬μ©ν ν¨μ
|
| 56 |
title="Kanana 1.5 μ±λ΄ ν
μ€νΈ π€",
|
| 57 |
description=f"{MODEL_NAME} λͺ¨λΈμ ν
μ€νΈν©λλ€.",
|
| 58 |
-
theme="soft",
|
| 59 |
-
examples=[["νκ΅μ μλλ μ΄λμΌ?"], ["AIμ λν΄ 3μ€λ‘ μμ½ν΄μ€."]]
|
| 60 |
-
# retry_btn
|
| 61 |
-
# undo_btn="μ΄μ λν μμ ", <-- μ΄ λΆλΆλ μ΅μ λ²μ μμ μ΄λ¦μ΄ λ€λ₯Ό μ μμ΄ μ κ±°
|
| 62 |
-
clear_btn="μ 체 λν μ΄κΈ°ν" # 'clear_btn'μ μμ§ μ ν¨ν©λλ€.
|
| 63 |
)
|
| 64 |
|
| 65 |
-
# ---
|
|
|
|
|
|
|
|
|
from threading import Thread

# --- 1. Model loading (uses the Space's GPU) ---
# NOTE (translated from Korean): the "2505" model does not exist yet; pinned to
# the current latest release, 2405.
MODEL_NAME = "kakaocorp/kanana-1.5-2.1b-instruct-2405"

print(f"λͺ¨λΈμ λ‘λ© μ€μλλ€: {MODEL_NAME} (Space GPU μ¬μ©)")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,  # half precision to reduce GPU memory use
    device_map="auto"  # important: place the weights on the available GPU automatically
)
print("λͺ¨λΈ λ‘λ© μλ£!")
# --- 2. Chatbot response function (Gradio calls this) ---
# message: the message the user just typed
# history: prior conversation turns (managed automatically by Gradio)
def predict(message, history):
    """Stream a model reply for *message* given the chat *history*.

    Yields the accumulated response text after every streamed chunk so the
    Gradio UI shows a live "typing" effect.
    """
    # Kanana prompt format: <bos>user\n{prompt}\n<eos>assistant\n
    # Render each past (user, assistant) turn in that format, then append
    # the current message with an open assistant slot for the model to fill.
    past_turns = [
        f"<bos>user\n{user_turn}\n<eos>assistant\n{bot_turn}\n"
        for user_turn, bot_turn in history
    ]
    final_prompt = "".join(past_turns) + f"<bos>user\n{message}\n<eos>assistant\n"

    inputs = tokenizer(final_prompt, return_tensors="pt").to(model.device)

    # --- Streaming setup for the real-time typing effect ---
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        **inputs,  # unpack the tokenizer output (input_ids, attention_mask, ...)
        streamer=streamer,
        max_new_tokens=1024,  # cap on generated tokens
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        temperature=0.7,  # sampling temperature
        do_sample=True,  # enable sampling
    )

    # Run generate() on a worker thread so this generator can consume the
    # streamer concurrently.
    worker = Thread(target=model.generate, kwargs=generation_kwargs)
    worker.start()

    # Yield the growing transcript as each decoded chunk arrives.
    reply_so_far = ""
    for chunk in streamer:
        reply_so_far += chunk
        yield reply_so_far
# --- 3. Gradio chatbot UI ---
# gr.ChatInterface builds a complete chat-style UI around the predict function.
chatbot_ui = gr.ChatInterface(
    fn=predict,  # the function the chatbot will use
    title="Kanana 1.5 μ±λ΄ νμ€νΈ π€",
    description=f"{MODEL_NAME} λͺ¨λΈμ νμ€νΈν©λλ€.",
    theme="soft",  # built-in theme name
    examples=[["νκ΅μ μλλ μ΄λμΌ?"], ["AIμ λν΄ 3μ€λ‘ μμ½ν΄μ€."]]
    # NOTE (translated from Korean): the retry_btn, undo_btn, and clear_btn
    # parameters are not supported by the current Gradio version, so they
    # were removed.
)

# --- 4. Run the app ---
# .launch() starts the app on the Space.
chatbot_ui.launch()
|