File size: 3,412 Bytes
9443a16
 
 
 
 
 
0a56d56
4e2b6c5
e476f12
9443a16
 
e476f12
9443a16
 
 
 
 
 
 
 
 
bd0158c
9443a16
 
 
 
 
 
 
 
 
 
 
60bcc4a
9443a16
 
 
 
 
bd0158c
9443a16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3279b8c
9443a16
 
60bcc4a
9443a16
 
 
 
 
 
 
 
 
 
 
0a56d56
9443a16
bd0158c
9443a16
 
 
1aa87ff
bd0158c
 
9443a16
bd0158c
 
 
 
 
 
 
 
 
3279b8c
 
 
 
 
 
 
 
9443a16
 
 
 
 
 
bd0158c
9443a16
 
 
 
 
 
 
bd0158c
9443a16
 
 
 
 
 
 
bd0158c
9443a16
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import spaces
import gradio as gr
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

MODELS = ["Qwen/Qwen2-1.5B-Instruct", "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8"]

# Model repo id comes from the environment. Fall back to the first known
# model when MODEL_ID is unset: os.environ.get would otherwise return None
# and the .split("/") below would raise AttributeError at import time.
model = os.environ.get("MODEL_ID", MODELS[0])
model_name = model.split("/")[-1]

DESCRIPTION = f"""
<h3>MODEL: <a href="https://hf.co/{model}">{model_name}</a></h3>
<center>
<p>Qwen is the large language model built by Alibaba Cloud.
<br>
Feel free to test without log.
</p>
</center>
"""

# Page-level CSS: center the heading and hide the default Gradio footer.
css = """
h3 {
    text-align: center;
}
footer {
    visibility: hidden;
}
"""


# Tokenizer is loaded once at startup; used below to render the chat template.
tokenizer = AutoTokenizer.from_pretrained(model)



# Lazily-initialized vLLM engine, shared across requests so the model
# weights are loaded once instead of on every chat turn (the original
# constructed a fresh LLM inside generate(), reloading the model per call).
_llm = None


def _get_llm():
    """Return the process-wide LLM engine, creating it on first use."""
    global _llm
    if _llm is None:
        _llm = LLM(model=model)
    return _llm


@spaces.GPU
def generate(message, history, system, max_tokens, temperature, top_p, top_k, penalty):
    """Produce one assistant reply for the Gradio ChatInterface callback.

    Args:
        message: Latest user message.
        history: Prior (user, assistant) turn pairs from the chat widget.
        system: System prompt text.
        max_tokens: Maximum number of tokens to generate.
        temperature, top_p, top_k, penalty: Sampling controls forwarded to
            vLLM's SamplingParams (penalty maps to repetition_penalty).

    Returns:
        The generated assistant text, or "" if vLLM returns no output.
    """
    # Rebuild the full conversation in chat-message format.
    conversation = [
        {"role": "system", "content": system}
    ]
    for prompt, answer in history:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
    conversation.append({"role": "user", "content": message})

    # Render with the model's chat template; add_generation_prompt appends
    # the assistant header so the model continues as the assistant.
    text = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True
    )
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=penalty,
        max_tokens=max_tokens,
        # Qwen2 end-of-turn / end-of-text token ids — presumably
        # <|im_end|> and <|endoftext|>; confirm against the tokenizer.
        stop_token_ids=[151645, 151643],
    )

    outputs = _get_llm().generate([text], sampling_params)
    if not outputs:
        # Defensive: vLLM should return one RequestOutput per prompt.
        return ""
    generated_text = outputs[0].outputs[0].text
    print(f"Prompt: {outputs[0].prompt!r}, Generated text: {generated_text!r}")
    return generated_text

    
# Chat transcript display area.
chatbot = gr.Chatbot(height=800)

with gr.Blocks(css=css) as demo:
    gr.HTML(DESCRIPTION)

    # All sampling controls live in a collapsed accordion; every widget is
    # created with render=False so ChatInterface can place it there itself.
    params_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    system_box = gr.Textbox(
        value="You are a helpful assistant.",
        label="System message",
        render=False,
    )
    max_tokens_slider = gr.Slider(
        minimum=1,
        maximum=30720,
        value=2048,
        step=1,
        label="Max tokens",
        render=False,
    )
    temperature_slider = gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.7,
        step=0.1,
        label="Temperature",
        render=False,
    )
    top_p_slider = gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.95,
        step=0.05,
        label="Top-p",
        render=False,
    )
    top_k_slider = gr.Slider(
        minimum=0,
        maximum=20,
        value=20,
        step=1,
        label="Top-k",
        render=False,
    )
    penalty_slider = gr.Slider(
        minimum=0.0,
        maximum=2.0,
        value=1,
        step=0.1,
        label="Repetition penalty",
        render=False,
    )

    # Wire the generation callback to the chat UI; the additional inputs
    # are passed to generate() in this exact order after (message, history).
    gr.ChatInterface(
        fn=generate,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=params_accordion,
        additional_inputs=[
            system_box,
            max_tokens_slider,
            temperature_slider,
            top_p_slider,
            top_k_slider,
            penalty_slider,
        ],
        retry_btn="Retry",
        undo_btn="Undo",
        clear_btn="Clear",
        submit_btn="Send",
    )

if __name__ == "__main__":
    demo.launch()