# chatbot_demo / app.py
# Source: Hugging Face Space by Daeyongkwon98 (commit e3ab0ad, "Update app.py", 3.1 kB).
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
from string import Template
from huggingface_hub import login
# Log in to the Hugging Face Hub with the token from the ACCESS_TOKEN environment
# variable (meta-llama checkpoints are gated on the Hub, so downloading them
# needs an authorized token).
_access_token = os.getenv("ACCESS_TOKEN")
if _access_token:
    login(_access_token)
# When ACCESS_TOKEN is unset, login(None) would fall back to an interactive
# prompt that cannot be answered on a headless server, so it is skipped; the
# downloads below then use whatever credentials huggingface_hub finds itself
# (e.g. the HF_TOKEN variable or a cached login).

# Plain-text "Human:/Assistant:" prompt template with one `${inst}` placeholder.
prompt_template = Template("Human: ${inst} </s> Assistant: ")

# Load the tokenizer and the model (in inference mode) from the Hub.
model_name = "meta-llama/Llama-3.2-1b-instruct"  # Hub repository id
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
# NOTE(review): float16 matmuls on CPU are slow and are not implemented on some
# older torch builds; consider torch.float32 or torch.bfloat16 for device_map="cpu".
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="cpu",
).eval()

# Default sampling settings. max_new_tokens, temperature and top_p are
# overridden per request from the Gradio UI controls.
default_generation_config = GenerationConfig(
    temperature=0.1,
    top_k=30,
    top_p=0.5,
    do_sample=True,
    num_beams=1,
    repetition_penalty=1.1,
    min_new_tokens=10,
    max_new_tokens=30,
)
# Chat response generation.
def _history_to_messages(history):
    """Convert Gradio chat history into chat-template messages.

    Accepts both history formats Gradio's ChatInterface can pass: "tuples"
    (a list of ``[user, assistant]`` pairs) and "messages" (a list of
    ``{"role": ..., "content": ...}`` dicts). Turns with an empty user message
    and non-text entries (e.g. uploaded files) are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
            continue
        user_text, bot_text = turn[0], turn[1]
        if not isinstance(user_text, str) or not user_text:
            continue
        messages.append({"role": "user", "content": user_text})
        # A pending/failed reply is None; do not render it as the text "None".
        if isinstance(bot_text, str) and bot_text:
            messages.append({"role": "assistant", "content": bot_text})
    return messages


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Stream the assistant's reply to ``message`` for ``gr.ChatInterface``.

    Args:
        message: The user's latest message.
        history: Previous turns, in Gradio "tuples" or "messages" format.
        system_message: System prompt; left out of the prompt when empty.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The reply generated so far, growing as tokens are produced.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised after any partial output has been yielded.
    """
    from threading import Thread

    from transformers import TextIteratorStreamer

    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    # The instruct model expects its own chat format (special header and
    # end-of-turn tokens); apply_chat_template renders it and appends the
    # assistant header so the model continues as the assistant.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    max_new_tokens = int(max_tokens)  # UI sliders may deliver floats
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        **inputs,
        "generation_config": default_generation_config,
        # Per-call overrides: generate() applies them to a copy of the config,
        # so the shared default is never mutated.
        "max_new_tokens": max_new_tokens,
        # Never demand more tokens than the user allowed.
        "min_new_tokens": min(default_generation_config.min_new_tokens or 0, max_new_tokens),
        "temperature": float(temperature),
        "top_p": float(top_p),
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.eos_token_id,  # the tokenizer defines no pad token
        "streamer": streamer,
    }

    errors = []

    def _generate():
        # Runs in a worker thread so the caller can yield tokens as they arrive.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # re-raised in the caller below
            errors.append(exc)
            streamer.end()  # unblock the consumer loop instead of hanging it

    worker = Thread(target=_generate, daemon=True)
    worker.start()

    response = ""
    for chunk in streamer:
        response += chunk
        yield response
    worker.join()
    if errors:
        raise errors[0]
# Gradio chat UI. ChatInterface calls `respond(message, history, *extras)`, where
# `extras` are the current values of the controls below, in list order.
_DEFAULT_SYSTEM_MESSAGE = "You are a friendly and knowledgeable assistant who can discuss a wide range of topics related to music, including genres, artists, albums, instruments, and music history."

_extra_controls = [
    gr.Textbox(value=_DEFAULT_SYSTEM_MESSAGE, label="System message"),
    gr.Slider(minimum=1, maximum=2048, value=30, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05, label="Top-p (nucleus sampling)"),
]

demo = gr.ChatInterface(respond, additional_inputs=_extra_controls)

if __name__ == "__main__":
    # Start the web server only when run as a script.
    demo.launch()