File size: 4,042 Bytes
b8c24aa
3a82207
63b82b4
 
 
 
42799d2
63b82b4
c8fdb3b
3a82207
4e81072
deaeb85
00f3401
33cc946
00f3401
 
 
 
 
08c1bd3
33cc946
4e81072
7dc3087
b693a74
42799d2
 
13bee58
ecf6383
63b82b4
81e5bac
00f3401
 
63b82b4
895beee
ecf6383
895beee
 
42799d2
895beee
33cc946
 
ea9c0d3
7115ad7
ea9c0d3
7dc3087
33cc946
64d8a64
63b82b4
64d8a64
 
63b82b4
64d8a64
63b82b4
c7f7d96
08c1bd3
33cc946
a6b8174
00f3401
33cc946
3a82207
 
 
 
 
 
33cc946
 
3a82207
 
33cc946
 
3a82207
a6b8174
63b82b4
33cc946
 
3a82207
 
 
33cc946
 
 
 
3a82207
00f3401
33cc946
ea9c0d3
00f3401
 
33cc946
3a82207
 
 
33cc946
3a82207
 
 
 
 
00f3401
3a82207
33cc946
63b82b4
 
34b43d3
63b82b4
 
 
 
 
b638764
63b82b4
00f3401
63b82b4
 
 
 
ea9c0d3
63b82b4
 
 
 
9a34670
63b82b4
b693a74
63b82b4
33cc946
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    LlamaTokenizer,
)
import os
from threading import Thread
import spaces
import subprocess

# Install the flash-attn library at startup, skipping its CUDA build step.
# The skip flag is merged into a copy of the current environment: passing a
# dict containing only the flag as `env` would replace the whole environment
# (dropping PATH, HOME, virtualenv variables, ...), so `pip` might not be found
# or might install into a different interpreter.
# Best-effort: the return code is deliberately not checked, so a failed
# install does not prevent the app from starting.
subprocess.run(
    ["pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

# Hugging Face access token, read from the HF_TOKEN environment variable.
# Indexing os.environ means startup fails with KeyError if it is unset.
token = os.environ["HF_TOKEN"]

# Load the apple/OpenELM-270M-Instruct model and a tokenizer to pair with it.
# Notes from earlier experiments (translated from the original comments):
#   - The tokenizer raised errors, so NousResearch/Llama-2-7b-hf is used.
#   - Tried switching to a Korean model tokenizer: beomi/llama-2-ko-7b.
#   - Tried using only a larger tokenizer (apple/OpenELM-1.1B) <- did not work.
#   - Tried changing both to apple/OpenELM-3B-Instruct <- did not work.
_hub_options = {"token": token, "trust_remote_code": True}
model = AutoModelForCausalLM.from_pretrained(
    "apple/OpenELM-270M-Instruct", **_hub_options
)
tok = AutoTokenizer.from_pretrained(
    "NousResearch/Llama-2-7b-hf", tokenizer_class=LlamaTokenizer, **_hub_options
)
del _hub_options  # temporary helper only; keep the module namespace unchanged

# Token ids passed to generate() as `eos_token_id` (end-of-generation ids).
terminators = [tok.eos_token_id]

# Place the model on the GPU when CUDA is available, otherwise on the CPU,
# and report which one was chosen.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    print("Using CPU")

model = model.to(device)

def _history_to_messages(message, history):
    """Build a chat-template message list from Gradio history plus the new message.

    Accepts both Gradio history formats:
      * "tuples": each item is a ``[user_text, assistant_text]`` pair, where
        ``assistant_text`` is ``None`` for a turn that has no reply yet;
      * "messages": each item is a dict with ``"role"`` and ``"content"`` keys.
    A ``None`` history is treated as empty.
    """
    conversation = []
    for item in history or []:
        if isinstance(item, dict):
            conversation.append({"role": item["role"], "content": item["content"]})
            continue
        user_text, assistant_text = item[0], item[1]
        conversation.append({"role": "user", "content": user_text})
        if assistant_text is not None:
            conversation.append({"role": "assistant", "content": assistant_text})
    conversation.append({"role": "user", "content": message})
    return conversation


# Run `chat` on Spaces GPU resources; each call may hold the GPU for up to 60 seconds.
@spaces.GPU(duration=60)
def chat(message, history, temperature, do_sample, max_tokens):
    """Stream the model's reply to ``message`` given the earlier ``history``.

    Args:
        message: The new user message.
        history: Previous turns supplied by ``gr.ChatInterface``
            (tuples or messages format).
        temperature: Sampling temperature from the UI slider. ``0`` forces
            greedy decoding.
        do_sample: The "Sampling" checkbox. When false, greedy decoding is
            used and ``temperature`` is not passed to ``generate``.
        max_tokens: Maximum number of new tokens to generate (cast to int).

    Yields:
        The reply text accumulated so far, growing as tokens stream in.
    """
    conversation = _history_to_messages(message, history)

    # Render the conversation with the tokenizer's chat template, then tokenize.
    # The rendered prompt may already contain BOS; only let the tokenizer add
    # special tokens when it does not, to avoid a doubled BOS token.
    prompt = tok.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    bos = tok.bos_token
    template_has_bos = bool(bos) and prompt.startswith(bos)
    model_inputs = tok(
        [prompt], return_tensors="pt", add_special_tokens=not template_has_bos
    ).to(device)

    # Stream decoded text (without the prompt or special tokens) as it is generated.
    streamer = TextIteratorStreamer(
        tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )

    # Sample only when the checkbox is on AND temperature is positive;
    # a temperature of 0 means greedy decoding.
    sample = bool(do_sample) and temperature > 0
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),  # maximum number of new tokens
        do_sample=sample,
        eos_token_id=terminators,  # end-of-generation token ids
    )
    if sample:
        # Higher temperature -> more diverse output. Omitted for greedy decoding,
        # where generate() would ignore it and warn.
        generate_kwargs["temperature"] = temperature

    # Generate in a background thread so this generator can consume the streamer.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    # The streamer is exhausted once generation finishes; reap the worker thread.
    thread.join()
    yield partial_text

# Build the interactive chat UI with Gradio's ChatInterface. The additional
# inputs are passed to `chat` after (message, history), in this order:
# temperature, do_sample, max_tokens.
demo = gr.ChatInterface(
    fn=chat,
    examples=[["let's talk about korea"]],
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Slider(
            minimum=0, maximum=1, step=0.1, value=0.7, label="Temperature", render=False
        ),
        # render=False for consistency with the other additional inputs,
        # which ChatInterface renders itself.
        gr.Checkbox(label="Sampling", value=True, render=False),
        gr.Slider(
            minimum=128,
            maximum=4096,
            step=1,
            value=512,
            label="Max new tokens",
            render=False,
        ),
    ],
    stop_btn="Stop Generation",
    title="Chat With LLMs",
    # Must name the checkpoint actually loaded above (the -Instruct variant).
    description="Now Running [apple/OpenELM-270M-Instruct](https://huggingface.co/apple/OpenELM-270M-Instruct)",
)

# Start the Gradio app.
demo.launch()