File size: 3,103 Bytes
5fd0daa
154a87e
a51e5f4
 
 
5fd0daa
 
 
 
154a87e
a51e5f4
 
154a87e
5fd0daa
ed12022
128b55e
d14f87e
154a87e
a51e5f4
 
32006fa
 
 
a51e5f4
 
 
 
32006fa
a51e5f4
154a87e
a51e5f4
 
54b7145
a51e5f4
 
 
f6b2655
 
 
a51e5f4
 
 
154a87e
 
a51e5f4
 
 
 
 
32006fa
 
 
 
 
 
a51e5f4
 
 
 
 
154a87e
a51e5f4
154a87e
 
 
54b7145
a51e5f4
154a87e
 
 
e3ab0ad
32006fa
 
 
154a87e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
from string import Template
from huggingface_hub import login

# Hugging Face์— ๋กœ๊ทธ์ธ (ํ™˜๊ฒฝ ๋ณ€์ˆ˜์—์„œ Access Token ๊ฐ€์ ธ์˜ค๊ธฐ)
login(os.getenv("ACCESS_TOKEN"))  # ACCESS_TOKEN์„ ํ™˜๊ฒฝ ๋ณ€์ˆ˜์—์„œ ๋ถˆ๋Ÿฌ์˜ด

# ํ”„๋กฌํ”„ํŠธ ํ…œํ”Œ๋ฆฟ ์„ค์ •
prompt_template = Template("Human: ${inst} </s> Assistant: ")

# ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ
model_name = "meta-llama/Llama-3.2-1b-instruct"  # ๋ชจ๋ธ ๊ฒฝ๋กœ
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="cpu").eval()

# ์ƒ์„ฑ ์„ค์ • (Gradio UI์—์„œ ์ œ์–ดํ•  ์ˆ˜ ์žˆ๋Š” ๋ณ€์ˆ˜๋“ค)
default_generation_config = GenerationConfig(
    temperature=0.1,
    top_k=30,
    top_p=0.5,
    do_sample=True,
    num_beams=1,
    repetition_penalty=1.1,
    min_new_tokens=10,
    max_new_tokens=30
)

# ์‘๋‹ต ์ƒ์„ฑ ํ•จ์ˆ˜
def respond(message, history, system_message, max_tokens, temperature, top_p):
    # ์ƒ์„ฑ ์„ค์ •
    generation_config = GenerationConfig(
        **default_generation_config.to_dict()  # ๊ธฐ๋ณธ ์„ค์ •๊ณผ ๋ณ‘ํ•ฉ
    )
    generation_config.max_new_tokens = max_tokens  # max_tokens ๋”ฐ๋กœ ์„ค์ •
    generation_config.temperature = temperature   # temperature ๋”ฐ๋กœ ์„ค์ •
    generation_config.top_p = top_p
    
    # ๋Œ€ํ™” ํžˆ์Šคํ† ๋ฆฌ์™€ ์‹œ์Šคํ…œ ๋ฉ”์‹œ์ง€๋ฅผ ํฌํ•จํ•œ ํ”„๋กฌํ”„ํŠธ ๊ตฌ์„ฑ
    prompt = prompt_template.safe_substitute({"inst": system_message})
    for val in history:
        if val[0]:
            prompt += f"Human: {val[0]} </s> Assistant: {val[1]} </s> "
    prompt += f"Human: {message} </s> Assistant: "
    
    # ๋ชจ๋ธ ์ž…๋ ฅ ์ƒ์„ฑ
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    response_ids = model.generate(
        **inputs, 
        generation_config=generation_config, 
        eos_token_id=tokenizer.eos_token_id,  # ์ข…๋ฃŒ ํ† ํฐ ์„ค์ •
        pad_token_id=tokenizer.eos_token_id   # pad_token_id๋„ ์ข…๋ฃŒ ํ† ํฐ์œผ๋กœ ์„ค์ •
    )
    
    # ๋ชจ๋ธ ์‘๋‹ต ๋””์ฝ”๋”ฉ
    response_text = tokenizer.decode(response_ids[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    
    # ์‹ค์‹œ๊ฐ„ ์‘๋‹ต์„ ์œ„ํ•œ ๋ถ€๋ถ„์  ํ…์ŠคํŠธ ๋ฐ˜ํ™˜
    response = ""
    for token in response_text:
        response += token
        yield response


# Gradio Chat Interface ์„ค์ •
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly and knowledgeable assistant who can discuss a wide range of topics related to music, including genres, artists, albums, instruments, and music history.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=30, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()