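"""Gradio demo of a Korean medical chatbot.

The app loads a lit-llama 7B checkpoint with Lightning Fabric, wraps the
token-by-token sampling loop in a ChatBot class, and serves a gr.ChatInterface
where a patient can describe symptoms and receive a generated answer.
"""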
from pathlib import Path
import gradio as gr
import lightning as L
import torch
from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import EmptyInitOnDevice
class ChatBot:
    def __init__(self, model, tokenizer, fabric):
        self.model = model
        self.tokenizer = tokenizer
        self.fabric = fabric
    def generate_prompt(self, example):
        # Build an instruction-style prompt; the second template (used when there is
        # no "input" field) frames the request as a patient asking a doctor about symptoms.
        if example["input"]:
            return (
                "μ•„λž˜λŠ” μž‘μ—…μ„ μ„€λͺ…ν•˜λŠ” λͺ…령어와 좔가적 λ§₯락을 μ œκ³΅ν•˜λŠ” μž…λ ₯이 짝을 μ΄λ£¨λŠ” μ˜ˆμ œμž…λ‹ˆλ‹€.\n\n"
                "μš”μ²­μ„ 적절히 μ™„λ£Œν•˜λŠ” 응닡을 μž‘μ„±ν•˜μ„Έμš”.\n\n"
                f"### λͺ…λ Ήμ–΄:\n{example['instruction']}\n\n### μž…λ ₯:\n{example['input']}\n\n### 응닡:"
            )
        return (
            "ν™˜μžκ°€ μ˜μ‚¬μ—κ²Œ μ•„ν”ˆ 곳에 λŒ€ν•΄ λ¬Έμ˜ν•©λ‹ˆλ‹€.\n\n"
            "ν™˜μžμ˜ 문의 λ‚΄μš©μ— λŒ€ν•΄ λ‹΅λ³€ν•˜μ„Έμš”. ν™˜μžμ˜ μ§ˆλ³‘μ„ μ§„λ‹¨ν•˜κ³ , κ°€λŠ₯ν•˜λ©΄ μ²˜λ°©μ„ ν•˜μ„Έμš”. \n\n"
            f"### 문의:\n{example['instruction']}\n\n### 응닡:"
        )
    # default generation
    @torch.no_grad()
    def generate(
        self,
        idx,
        max_new_tokens,
        max_seq_length=None,
        temperature=0.8,
        top_k=None,
        eos_id=None,
        repetition_penalty=1.1,
    ):
        T = idx.size(0)
        T_new = T + max_new_tokens
        if max_seq_length is None:
            max_seq_length = min(T_new, self.model.config.block_size)

        device, dtype = idx.device, idx.dtype
        # create an empty tensor of the expected final shape and fill in the current tokens
        empty = torch.empty(T_new, dtype=dtype, device=device)
        empty[:T] = idx
        idx = empty
        input_pos = torch.arange(0, T, device=device)

        if idx.device.type == "xla":
            import torch_xla.core.xla_model as xm

            xm.mark_step()

        # generate max_new_tokens tokens
        for _ in range(max_new_tokens):
            x = idx.index_select(0, input_pos).view(1, -1)

            # forward
            logits = self.model(x, max_seq_length, input_pos)
            logits = logits[0, -1] / temperature

            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

            probs = torch.nn.functional.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)

            # advance
            input_pos = input_pos[-1:] + 1

            if idx.device.type == "xla":
                xm.mark_step()

            # concatenate the new generation
            idx = idx.index_copy(0, input_pos, idx_next)

            # if <eos> token is triggered, return the output (stop generation)
            if idx_next == eos_id:
                return idx[:input_pos]  # include the EOS token

        return idx
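    # NOTE: generate() drives the model with explicit input_pos indices so the model's
    # cache is reused between steps; ans() below calls self.model.reset_cache() after
    # each reply so the next question starts from a clean cache.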
    # LLM generation function: called by gr.ChatInterface for every user message
    def ans(self, user_message, history, max_new_tokens, top_k, temperature):
        history = history + [[user_message, None]]
        instruction = history[-1][0].strip()
        sample = {"instruction": instruction, "input": None}
        prompt = self.generate_prompt(sample)
        encoded_prompt = self.tokenizer.encode(prompt, bos=True, eos=False, device=self.fabric.device)
        y = self.generate(
            idx=encoded_prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            eos_id=self.tokenizer.eos_id,
        )
        self.model.reset_cache()

        # the decoded text still contains the prompt; keep only what follows "### 응닡:"
        response = self.tokenizer.decode(y)
        response = response.split("응닡:")[1].strip()

        # record the reply in the local copy of the history
        history[-1][1] = response
        return response

def load_model():
    # Settings for inference
    # Precision setting for float32 matmul operations. It's important for some CUDA devices.
    torch.set_float32_matmul_precision("high")

    checkpoint_path = Path("checkpoints/lit-llama/7B/lit-llama.pth")
    tokenizer_path = Path("checkpoints/lit-llama/tokenizer.model")
    quantize = None  # "gptq.int4" or "llm.int8"

    fabric = L.Fabric(devices=1)
    dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32

    with EmptyInitOnDevice(device=fabric.device, dtype=dtype, quantization_mode=quantize):
        model = LLaMA.from_name("7B")

    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint)
    model.eval()
    model = fabric.setup_module(model)

    tokenizer = Tokenizer(tokenizer_path)
    return model, tokenizer, fabric

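# NOTE: the submit_btn/retry_btn/undo_btn/clear_btn arguments below follow the older
# (4.x-era) gradio ChatInterface signature; newer gradio releases may not accept
# gr.Button values for these parameters, so pin the gradio version if reusing this UI.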
# theme candidates: 'Taithrah/Minimal', 'abidlabs/dracula_test', 'JohnSmith9982/small_and_pretty'
def setup_gradio_ui(chat_bot, css):
    gr.ChatInterface(
        fn=chat_bot.ans,
        css=css,
        textbox=gr.Textbox(placeholder="μ§ˆλ¬Έμ„ μž…λ ₯ν•΄μ£Όμ„Έμš”.", container=False, scale=7),
        chatbot=gr.Chatbot(
            height=600,
            value=[[None, "μ•ˆλ…•ν•˜μ„Έμš”. 무엇이 κΆκΈˆν•˜μ‹ κ°€μš”?"]],
            avatar_images=["asset/human.png", "asset/bot.jpg"],
        ),
        title="의료용 챗봇 데λͺ¨",
        theme="soft",
        examples=[["두톡이 λ„ˆλ¬΄ μ‹¬ν•΄μš”."], ["λ°°κ°€ μ•„ν”„κ³  토할것 κ°™μ•„μš”."], ["ν—ˆλ¦¬κ°€ λŠμ–΄μ§ˆ 듯이 μ•„νŒŒμš”."]],
        submit_btn=gr.Button(value="전솑", icon="send.png", elem_id="green"),
        retry_btn=gr.Button(value="λ‹€μ‹œλ³΄λ‚΄κΈ° (재질문)↩", elem_id="blue"),
        undo_btn=gr.Button(value="이전챗 μ‚­μ œ ❌", elem_id="blue"),
        clear_btn=gr.Button(value="μ „μ±— μ‚­μ œ πŸ’«", elem_id="blue"),
        additional_inputs=[
            gr.Slider(
                minimum=1,
                maximum=512,
                step=1,
                value=512,
                label="max_new_tokens",
                info="μ΅œλŒ€ 생성 κ°€λŠ₯ 토큰 수",
                interactive=True,
            ),
            gr.Slider(
                minimum=1,
                maximum=300,
                step=1,
                value=150,
                label="top_k",
                info="ν™•λ₯ μ΄ κ°€μž₯ 높은 토큰 k개 μƒ˜ν”Œλ§",
                interactive=True,
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                step=0.1,
                value=0.5,
                label="temperature",
                info="1에 κ°€κΉŒμšΈμˆ˜λ‘ λ‹€μ–‘ν•œ λ‹΅λ³€ 생성",
                interactive=True,
            ),
        ],
    ).queue().launch()

def main():
    # load the model and tokenizer
    model, tokenizer, fabric = load_model()

    # create the chatbot instance
    chat_bot = ChatBot(model, tokenizer, fabric)

    # build and launch the UI
    css = """
    #green {background-color: #00EF91}
    #blue {background-color: #B9E2FA}
    """
    setup_gradio_ui(chat_bot, css)

if __name__ == "__main__":
    main()
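# Running locally (a minimal sketch; assumes the lit-llama package is installed and the
# converted weights exist at the paths hard-coded in load_model()):
#
#     python app.py
#
# or, to drive the chatbot without the Gradio UI:
#
#     model, tokenizer, fabric = load_model()
#     bot = ChatBot(model, tokenizer, fabric)
#     print(bot.ans("두톡이 λ„ˆλ¬΄ μ‹¬ν•΄μš”.", [], max_new_tokens=256, top_k=150, temperature=0.5))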