File size: 2,645 Bytes
1c776f7
5ee0c02
2c4a7ea
dd48380
5ee0c02
dd48380
2c4a7ea
1c776f7
 
dd48380
1c776f7
 
 
 
 
2c4a7ea
 
 
 
 
 
 
 
 
 
51a7d0f
 
 
dd48380
5e66ec0
 
 
 
 
 
 
5ee0c02
5e66ec0
5ee0c02
 
51a7d0f
e14b97a
5ee0c02
5e69eac
51a7d0f
5ee0c02
 
 
 
1c776f7
dd48380
5e66ec0
 
1c776f7
5e66ec0
1c776f7
dd48380
1c776f7
dd48380
 
fbcf846
dd48380
 
2c4a7ea
dd48380
 
2c4a7ea
1c776f7
 
 
 
5ee0c02
 
1c776f7
dd48380
 
73bf78a
5e66ec0
 
 
 
 
d8e275b
 
5e66ec0
 
 
 
d8e275b
5e66ec0
 
 
 
5ee0c02
5e66ec0
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import time
import requests
import gradio as gr
from huggingface_hub import get_inference_endpoint

# --- Runtime configuration, read once from the process environment ----------
# Every value is None when its variable is not set.

# Which inference endpoint to talk to, and the token used to authenticate.
endpoint_name = os.environ.get('ENDPOINT_NAME')
endpoint_url = os.environ.get('ENDPOINT_URL')
personal_secret_token = os.environ.get('PERSONAL_HF_TOKEN')

# Prompt-format markers: the separator placed between turns and the prefixes
# attached to each speaker's text.
turn_breaker = os.environ.get('TURN_BREAKER')
system_symbol = os.environ.get('SYSTEM_SYMBOL')
user_symbol = os.environ.get('USER_SYMBOL')
assistant_symbol = os.environ.get('ASSISTANT_SYMBOL')

# HTTP headers sent with every request to the endpoint (JSON in, JSON out).
headers = {
    "Accept": "application/json",
    "Authorization": f"Bearer {personal_secret_token}",
    "Content-Type": "application/json",
}

def query(payload):
    """Send ``payload`` to the inference endpoint and return its parsed reply.

    The payload is POSTed as a JSON body to ``endpoint_url`` using the
    module-level ``headers``; the response body is decoded from JSON and
    returned as-is. The HTTP status code is not inspected.
    """
    reply = requests.post(endpoint_url, json=payload, headers=headers)
    return reply.json()

def get_status():
    """Look up the configured inference endpoint and return its ``status``."""
    return get_inference_endpoint(
        endpoint_name,
        token=personal_secret_token,
    ).status

# Longest time, in seconds, one chat request waits for the endpoint to report
# "running" before giving up.
_WAKE_TIMEOUT_SECONDS = 600
# Pause, in seconds, between two consecutive status polls while waiting.
_POLL_INTERVAL_SECONDS = 1
# Statuses from which a plain inference request is not expected to bring the
# endpoint back. NOTE(review): verify against the status values documented
# for huggingface_hub inference endpoints.
_UNRECOVERABLE_STATUSES = ("failed", "paused")


def _wait_until_running(progress):
    """Return once the endpoint reports "running", waking it up if necessary.

    Raises:
        gr.Error: if the endpoint reports an unrecoverable status or does not
            become "running" within ``_WAKE_TIMEOUT_SECONDS``.
    """
    if get_status() == "running":
        return

    try:
        # The reply to this request is irrelevant; it is only sent to nudge the
        # endpoint into starting.
        query({"inputs": "wake up!"})
    except (requests.RequestException, ValueError):
        # Best effort: a transport error or a non-JSON body from an endpoint
        # that is still starting must not abort the chat turn. The polling
        # loop below decides whether the endpoint actually came up.
        pass
    progress(0.25, desc="Waking up model")

    deadline = time.monotonic() + _WAKE_TIMEOUT_SECONDS
    while True:
        status = get_status()
        if status == "running":
            return
        if status in _UNRECOVERABLE_STATUSES:
            raise gr.Error(
                f"The model endpoint is '{status}' and cannot be woken up by a request."
            )
        if time.monotonic() >= deadline:
            raise gr.Error(
                f"The model endpoint did not start within {_WAKE_TIMEOUT_SECONDS} seconds "
                f"(last status: '{status}')."
            )
        time.sleep(_POLL_INTERVAL_SECONDS)


def _build_prompt(system_message, history, message):
    """Join the system message, past turns and the new message into one prompt."""
    # NOTE(review): system_symbol is read from the environment in this file but
    # is not applied to system_message here, unlike the user/assistant
    # prefixes. Confirm whether that is intended.
    parts = [system_message]
    for user_text, assistant_text in history:
        # Empty or missing halves of a turn are skipped.
        if user_text:
            parts.append(user_symbol + user_text)
        if assistant_text:
            parts.append(assistant_symbol + assistant_text)
    parts.append(user_symbol + message)
    return turn_breaker.join(parts)


def _extract_reply(response):
    """Turn the endpoint's parsed JSON reply into something the chat UI can show.

    Handles the ``[{"generated_text": ...}]`` layout and ``{"error": ...}``
    payloads; any other reply (for instance a plain string from a custom
    handler) is returned unchanged.
    """
    if isinstance(response, dict) and "error" in response:
        raise gr.Error(str(response["error"]))
    if (
        isinstance(response, list)
        and response
        and isinstance(response[0], dict)
        and "generated_text" in response[0]
    ):
        return response[0]["generated_text"]
    return response


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_new_tokens,
    temperature,
    top_p,
    progress=gr.Progress()
):
    """Generate the assistant's reply for one chat turn.

    Args:
        message: The user's new message.
        history: Earlier turns as (user text, assistant text) pairs.
        system_message: Text placed at the very start of the prompt.
        max_new_tokens: Forwarded to the endpoint as ``max_new_tokens``.
        temperature: Forwarded as ``temperature``; sampling is enabled only
            when it is greater than zero.
        top_p: Forwarded to the endpoint as ``top_p``.
        progress: Gradio progress tracker used to report the current stage.

    Returns:
        The generated text when the endpoint replies with a recognised
        layout, otherwise the endpoint's parsed JSON reply unchanged.

    Raises:
        gr.Error: if the endpoint cannot be brought to the "running" state in
            time, or if it replies with an ``{"error": ...}`` payload.
    """
    progress(0, desc="Starting")

    # Bounded wait: without a deadline a failed or paused endpoint would keep
    # this request polling forever.
    _wait_until_running(progress)

    progress(0.5, desc="Generating")

    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        top_p=top_p,
        temperature=temperature,
    )

    response = query({
        "inputs": _build_prompt(system_message, history, message),
        "parameters": generation_kwargs,
    })

    progress(1, desc="Generating")

    return _extract_reply(response)


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Chat UI backed by `respond`. The additional inputs are listed in the same
# order as the extra parameters of `respond` (system message, max new tokens,
# temperature, top-p).
demo = gr.ChatInterface(
    fn=respond,
    show_progress="full",
    additional_inputs=[
        gr.Textbox(
            label="System message",
            value="请你扮演一个开心,积极的角色,名叫贺英旭,今年26岁,工作是程序员。你需要以这个身份和朋友们进行对话。",
        ),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=2048,
            step=1,
            value=64,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.7,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.7,
        ),
    ],
)


# Launch the Gradio app only when this file is run as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    demo.launch()