File size: 2,381 Bytes
31a1ff8
e2b5fc2
a7706d8
54fe16b
a7706d8
 
54fe16b
 
31a1ff8
54fe16b
a7706d8
54fe16b
a7706d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53b40bf
54fe16b
 
a7706d8
 
 
54fe16b
 
a7706d8
 
54fe16b
a7706d8
54fe16b
 
 
 
 
a7706d8
 
54fe16b
 
 
a7706d8
 
 
 
 
53b40bf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import spaces

import os
import json
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer


def _build_messages(message, history, system_prompt):
    """Assemble the role/content message list for a single chat request.

    Args:
        message: The newest user message.
        history: Iterable of ``(user_text, assistant_text)`` pairs from
            earlier turns.
        system_prompt: Text placed in the leading ``system`` message.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts: the system
        message, the previous turns in order, then the new user message.
    """
    messages = [{"role": "system", "content": system_prompt}]
    for human, assistant in history:
        # A turn with a missing side (None) cannot be rendered as text by the
        # chat template; drop the whole pair so user/assistant roles keep
        # alternating.
        if human is None or assistant is None:
            continue
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})
    return messages


@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_tokens):
    """Generate a reply for ``message`` and yield it as growing prefixes.

    Relies on the module-level ``tokenizer`` and ``llm`` globals, which are
    created in the ``__main__`` section of this file.

    Args:
        message: The newest user message.
        history: Iterable of ``(user_text, assistant_text)`` pairs.
        system_prompt: Text for the leading system message.
        temperature: Sampling temperature forwarded to ``SamplingParams``.
        max_tokens: Upper bound on generated tokens; coerced to ``int``
            because the value comes from a UI slider.

    Yields:
        Successively longer prefixes of the generated text, one character
        at a time, ending with the full reply. Nothing is yielded when the
        generated text is empty.
    """
    messages = _build_messages(message, history, system_prompt)
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    stop_tokens = ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=1,
        max_tokens=int(max_tokens),
        stop=stop_tokens,
    )
    completions = llm.generate(prompt, sampling_params)
    for output in completions:
        # Log the exact prompt the engine reports and the text it produced.
        print('==========================question=============================')
        print(output.prompt)
        generated_text = output.outputs[0].text
        print('===========================answer=============================')
        print(generated_text)
        # The text is already complete at this point; emitting cumulative
        # prefixes only makes it appear incrementally to the consumer.
        for end in range(1, len(generated_text) + 1):
            yield generated_text[:end]


if __name__ == "__main__":
    # The file's top-level imports do not bring in gradio, yet every UI
    # component below is referenced through `gr`; import it here so the
    # script does not fail with a NameError.
    import gradio as gr

    # `tokenizer` and `llm` become module-level globals that `predict` reads.
    path = "stabilityai/stablelm-2-12b-chat"
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    llm = LLM(model=path, tensor_parallel_size=1, trust_remote_code=True)

    # Build the chat UI around `predict`; the three additional inputs are
    # passed to it, in order, as system_prompt, temperature and max_tokens.
    gr.ChatInterface(
        predict,
        title="LLM playground",
        description="This is a LLM playground for StableLM",
        theme="soft",
        chatbot=gr.Chatbot(height=1400, label="Chat History",),
        textbox=gr.Textbox(placeholder="input", container=False, scale=7),
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
        additional_inputs=[
            # Default system prompt (spelling corrected from "hepful").
            gr.Textbox("You are a helpful assistant.", label="System Prompt"),
            gr.Slider(0, 1, 0.7, label="Temperature"),
            gr.Slider(100, 2048, 1024, label="Max Tokens"),
        ],
        additional_inputs_accordion_name="Parameters",
        examples=[
            ["implement snake game using pygame"],
            ["Can you explain briefly to me what is the Python programming language?"],
            ["write a program to find the factorial of a number"],
        ],
    ).queue().launch()