from pathlib import Path

import gradio as gr
import lightning as L
import torch

from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import EmptyInitOnDevice

class ChatBot:
    def __init__(self, model, tokenizer, fabric):
        self.model = model
        self.tokenizer = tokenizer
        self.fabric = fabric

    def generate_prompt(self, example):
        # Instruction with additional input context (Alpaca-style template).
        if example["input"]:
            return (
                "Below is an instruction that describes a task, paired with an input that provides further context.\n\n"
                "Write a response that appropriately completes the request.\n\n"
                f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:"
            )
        # Instruction only: frame the message as a patient's inquiry to a doctor.
        return (
            "A patient is asking a doctor about where it hurts.\n\n"
            "Answer the patient's inquiry. Diagnose the patient's disease and, if possible, give a prescription.\n\n"
            f"### Inquiry:\n{example['instruction']}\n\n### Response:"
        )

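    # For illustration, with {"instruction": "I have a severe headache.", "input": None}
    # the template above renders as (the instruction text is just a sample):
    #
    #   A patient is asking a doctor about where it hurts.
    #
    #   Answer the patient's inquiry. Diagnose the patient's disease and, if possible, give a prescription.
    #
    #   ### Inquiry:
    #   I have a severe headache.
    #
    #   ### Response:
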
    @torch.no_grad()
    def generate(
        self,
        idx,
        max_new_tokens,
        max_seq_length=None,
        temperature=0.8,
        top_k=None,
        eos_id=None,
        repetition_penalty=1.1,  # note: not used below
    ):
        T = idx.size(0)
        T_new = T + max_new_tokens
        if max_seq_length is None:
            max_seq_length = min(T_new, self.model.config.block_size)

        device, dtype = idx.device, idx.dtype

        # Pre-allocate the output buffer and copy the prompt tokens into it.
        empty = torch.empty(T_new, dtype=dtype, device=device)
        empty[:T] = idx
        idx = empty
        input_pos = torch.arange(0, T, device=device)

        if idx.device.type == "xla":
            import torch_xla.core.xla_model as xm

            xm.mark_step()

        # Generate up to max_new_tokens, one token per iteration.
        for _ in range(max_new_tokens):
            x = idx.index_select(0, input_pos).view(1, -1)

            # Forward pass, then scale the logits of the last position by the temperature.
            logits = self.model(x, max_seq_length, input_pos)
            logits = logits[0, -1] / temperature

            # Optionally restrict sampling to the top-k most likely tokens.
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

            probs = torch.nn.functional.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)

            # Advance the position index for the next step.
            input_pos = input_pos[-1:] + 1

            if idx.device.type == "xla":
                xm.mark_step()

            # Write the sampled token into the output buffer.
            idx = idx.index_copy(0, input_pos, idx_next)

            # Stop early once the end-of-sequence token is produced.
            if idx_next == eos_id:
                return idx[:input_pos]

        return idx

    def ans(self, user_message, history, max_new_tokens, top_k, temperature):
        history = history + [[user_message, None]]
        instruction = history[-1][0].strip()
        sample = {"instruction": instruction, "input": None}
        prompt = self.generate_prompt(sample)
        encoded_prompt = self.tokenizer.encode(prompt, bos=True, eos=False, device=self.fabric.device)

        y = self.generate(
            idx=encoded_prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            eos_id=self.tokenizer.eos_id,
        )

        # Clear the KV cache so the next request starts fresh.
        self.model.reset_cache()

        # Keep only the text after the response marker of the prompt template.
        response = self.tokenizer.decode(y)
        response = response.split("### Response:")[1].strip()

        history[-1][1] = response
        return response

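# Illustrative only (not used by the app): a minimal, self-contained sketch of the
# sampling rule applied inside ChatBot.generate, i.e. temperature scaling followed by
# optional top-k filtering and multinomial sampling, written with plain torch ops.
def _sample_next_token_sketch(logits, temperature=0.8, top_k=None):
    # Lower temperature sharpens the distribution; higher temperature flattens it.
    logits = logits / temperature
    if top_k is not None:
        # Mask out (set to -inf) everything outside the k largest logits.
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits = torch.where(logits < v[-1], torch.full_like(logits, -float("inf")), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # Draw a single token id from the remaining probability mass.
    return torch.multinomial(probs, num_samples=1)
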
def load_model():
    torch.set_float32_matmul_precision("high")

    checkpoint_path = Path("checkpoints/lit-llama/7B/lit-llama.pth")
    tokenizer_path = Path("checkpoints/lit-llama/tokenizer.model")
    quantize = None

    fabric = L.Fabric(devices=1)
    # Prefer bfloat16 on CUDA devices that support it; fall back to float32 otherwise.
    dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32

    # Instantiate the model weights directly on the target device in the chosen dtype.
    with EmptyInitOnDevice(device=fabric.device, dtype=dtype, quantization_mode=quantize):
        model = LLaMA.from_name("7B")

    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint)

    model.eval()
    model = fabric.setup_module(model)

    tokenizer = Tokenizer(tokenizer_path)

    return model, tokenizer, fabric

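# Optional (a sketch; the mode name is an assumption that depends on your lit-llama
# version): if your checkout supports quantized loading, the same context manager in
# load_model() can be given a quantization mode instead of None to reduce GPU memory:
#
#   with EmptyInitOnDevice(device=fabric.device, dtype=dtype, quantization_mode="llm.int8"):
#       model = LLaMA.from_name("7B")
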
def setup_gradio_ui(chat_bot, css):
    gr.ChatInterface(
        fn=chat_bot.ans,
        css=css,
        textbox=gr.Textbox(placeholder="Please enter your question.", container=False, scale=7),
        chatbot=gr.Chatbot(
            height=600,
            value=[[None, "Hello. What would you like to know?"]],
            avatar_images=["asset/human.png", "asset/bot.jpg"],
        ),
        title="Medical Chatbot Demo",
        theme="soft",
        examples=[
            ["I have a severe headache."],
            ["My stomach hurts and I feel like vomiting."],
            ["My lower back hurts as if it were breaking."],
        ],
        submit_btn=gr.Button(value="Send", icon="send.png", elem_id="green"),
        retry_btn=gr.Button(value="Resend (ask again)", elem_id="blue"),
        undo_btn=gr.Button(value="Delete previous chat", elem_id="blue"),
        clear_btn=gr.Button(value="Clear all chats", elem_id="blue"),
        additional_inputs=[
            gr.Slider(
                minimum=1,
                maximum=512,
                step=1,
                value=512,
                label="max_new_tokens",
                info="Maximum number of tokens to generate",
                interactive=True,
            ),
            gr.Slider(
                minimum=1,
                maximum=300,
                step=1,
                value=150,
                label="top_k",
                info="Sample from the k most likely tokens",
                interactive=True,
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                step=0.1,
                value=0.5,
                label="temperature",
                info="Values closer to 1 produce more diverse answers",
                interactive=True,
            ),
        ],
    ).queue().launch()

def main():
    model, tokenizer, fabric = load_model()

    chat_bot = ChatBot(model, tokenizer, fabric)

    css = """
    #green {background-color: #00EF91}
    #blue {background-color: #B9E2FA}
    """
    setup_gradio_ui(chat_bot, css)


if __name__ == "__main__":
    main()
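

# Example usage (a sketch; the script name is an assumption about this repo's layout):
#
#   python chat_demo.py    # launches the Gradio demo on the default local port
#
# For a quick smoke test without the UI, the same pieces can be driven directly:
#
#   model, tokenizer, fabric = load_model()
#   bot = ChatBot(model, tokenizer, fabric)
#   print(bot.ans("I have a severe headache.", history=[],
#                 max_new_tokens=256, top_k=150, temperature=0.5))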