# Author: JUNGU
# Commit 2165e54 (verified): "Update app.py"
import gradio as gr
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
LlamaTokenizer,
)
import os
from threading import Thread
import spaces
import subprocess
# Install the flash-attn library at startup, skipping its CUDA build step
# (FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE).
# The override is merged into the current environment rather than replacing
# it, so PATH, HOME and virtualenv settings stay visible to pip. Running pip
# via sys.executable targets the same interpreter that runs this app.
# check=False keeps this best-effort: a failed install does not stop the app.
import sys

subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
# Hugging Face access token from the HF_TOKEN secret. Read with .get() so the
# app still starts when the secret is not configured; from_pretrained accepts
# token=None.
token = os.environ.get("HF_TOKEN")
# Load the model and tokenizer (originally planned: apple/OpenELM-270M).
# The tokenizer raised errors, so the tokenizer of NousResearch/Llama-2-7b-hf
# is used instead. Earlier attempts, kept for reference:
#   - tokenizer of the Korean model beomi/llama-2-ko-7b
#   - only a larger tokenizer, apple/OpenELM-1.1B  <- did not work
#   - apple/OpenELM-270M-Instruct for both  <- did not work
_hub_kwargs = {"token": token, "trust_remote_code": True}
model = AutoModelForCausalLM.from_pretrained(
    "apple/OpenELM-270M-Instruct", **_hub_kwargs
)
tok = AutoTokenizer.from_pretrained(
    "NousResearch/Llama-2-7b-hf",
    tokenizer_class=LlamaTokenizer,
    **_hub_kwargs,
)
# Token IDs that end generation; currently only the tokenizer's EOS token.
terminators = [tok.eos_token_id]
# Move the model to the GPU when PyTorch can see one; otherwise use the CPU.
_use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if _use_gpu else "cpu")
if _use_gpu:
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    print("Using CPU")
model = model.to(device)
def _build_prompt(conversation):
    """Render a list of role/content messages into one prompt string.

    Uses the tokenizer's chat template. If apply_chat_template raises
    ValueError (recent transformers releases do this when the tokenizer
    defines no chat template), fall back to a plain "Role: text" transcript
    that ends with an open assistant turn.
    """
    try:
        return tok.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
    except ValueError:
        lines = [f"{turn['role'].capitalize()}: {turn['content']}" for turn in conversation]
        lines.append("Assistant:")
        return "\n".join(lines)


# Run chat() on Hugging Face Spaces GPU resources; each call may hold the GPU
# for at most 60 seconds.
@spaces.GPU(duration=60)
def chat(message, history, temperature, do_sample, max_tokens):
    """Stream the model's reply to ``message`` given the prior ``history``.

    Args:
        message: The new user message.
        history: Prior turns; item[0] is the user text and item[1] the
            assistant text (may be None), as supplied by gr.ChatInterface.
        temperature: Sampling temperature; 0 forces greedy decoding.
        do_sample: Value of the "Sampling" checkbox. When false, decoding
            is greedy.
        max_tokens: Maximum number of new tokens to generate.

    Yields:
        The accumulated reply text after each streamed chunk.
    """
    # Convert the chat history into role/content messages.
    # (Named `conversation` so it does not shadow this function's name.)
    conversation = []
    for item in history:
        conversation.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            conversation.append({"role": "assistant", "content": item[1]})
    conversation.append({"role": "user", "content": message})

    prompt = _build_prompt(conversation)
    model_inputs = tok([prompt], return_tensors="pt").to(device)

    # Stream decoded text, excluding the prompt and special tokens.
    streamer = TextIteratorStreamer(
        tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )

    # Honor the "Sampling" checkbox (previously ignored: do_sample was always
    # True). A temperature of 0 still forces greedy decoding.
    sample = bool(do_sample) and temperature > 0
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),  # maximum number of new tokens
        do_sample=sample,
        eos_token_id=terminators,  # stop token IDs
    )
    if sample:
        # Only meaningful when sampling; higher values increase diversity.
        generate_kwargs["temperature"] = temperature

    # generate() blocks until done, so run it in a worker thread and consume
    # the streamer here.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    worker.join()
    # Final yield guarantees at least one update even if nothing was streamed.
    yield partial_text
# Build the chat UI with gr.ChatInterface. The additional inputs are passed to
# chat() after (message, history), in order: temperature, do_sample,
# max_tokens.
demo = gr.ChatInterface(
    fn=chat,
    examples=[["let's talk about korea"]],
    # Collapsed accordion that holds the generation controls below.
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Slider(
            minimum=0, maximum=1, step=0.1, value=0.7, label="Temperature", render=False
        ),
        # render=False, consistent with the other additional inputs.
        gr.Checkbox(label="Sampling", value=True, render=False),
        gr.Slider(
            minimum=128,
            maximum=4096,
            step=1,
            value=512,
            label="Max new tokens",
            render=False,
        ),
    ],
    stop_btn="Stop Generation",
    title="Chat With LLMs",
    # Name the checkpoint that is actually loaded (the -Instruct variant).
    description="Now Running [apple/OpenELM-270M-Instruct](https://huggingface.co/apple/OpenELM-270M-Instruct)",
)

# Start the Gradio server.
demo.launch()