import os
import uuid

import gradio as gr

import torch
from transformers import AutoTokenizer
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams

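# Generation limits for the demo; the input context cap can be overridden through
# the MAX_INPUT_TOKEN_LENGTH environment variable.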
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
MODEL_ID = "neuralmagic/OpenHermes-2.5-Mistral-7B-pruned50"

DESCRIPTION = f"""\
# NM vLLM Chat
Model: {MODEL_ID}
"""

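# vLLM requires a CUDA-capable GPU, so fail fast on CPU-only hardware.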
if not torch.cuda.is_available():
    raise ValueError("Running on CPU 🥶 This demo does not work on CPU.")

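# Build the async vLLM engine. The sparsity argument is nm-vllm specific and
# selects the sparse W16A16 kernels for this 50%-pruned checkpoint.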
engine_args = AsyncEngineArgs(
    model=MODEL_ID,
    sparsity="sparse_w16a16",
    max_model_len=MAX_INPUT_TOKEN_LENGTH,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.use_default_system_prompt = False


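# Streaming generator consumed by gr.ChatInterface: it rebuilds the conversation,
# applies the tokenizer's chat template, and yields partial completions as they
# arrive from the engine.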
async def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
):
    conversation = []

    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})

    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})

    formatted_conversation = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )

    sampling_params = SamplingParams(
        max_tokens=max_new_tokens,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )

    stream = await engine.add_request(
        uuid.uuid4().hex, formatted_conversation, sampling_params
    )

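    # Each RequestOutput carries the full text generated so far, so every yield
    # re-renders the assistant message with the latest partial completion.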
    async for request_output in stream:
        text = request_output.outputs[0].text
        yield text


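# The values of additional_inputs are passed to generate() after the message and
# chat history, in the order the components are declared below.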
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
)

# with gr.Blocks(css="style.css") as demo:
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    # gr.DuplicateButton(
    #     value="Duplicate Space for private use", elem_id="duplicate-button"
    # )
    chat_interface.render()

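# queue() serializes GPU-bound requests; max_size caps how many may wait in line
# before additional submissions are turned away.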
if __name__ == "__main__":
    demo.queue(max_size=20).launch()