# Gradio chat front-end for a Hugging Face Inference Endpoint.
# All connection details and prompt-format symbols come from environment variables.
import os
import time
import requests
import gradio as gr
from huggingface_hub import get_inference_endpoint

# Endpoint identity: name is used for status polling via huggingface_hub,
# URL is used for direct generation POSTs.
endpoint_name = os.getenv('ENDPOINT_NAME')
endpoint_url = os.getenv('ENDPOINT_URL')
# Token authenticates both the status check and the HTTP generation calls.
personal_secret_token = os.getenv('PERSONAL_HF_TOKEN')
# Prompt-format pieces: turns are joined with `turn_breaker`, and each turn is
# prefixed with its role symbol. Exact values must match the deployed model's
# chat template. NOTE(review): `system_symbol` is loaded but never applied in
# this file — confirm whether the system message should carry the prefix.
turn_breaker = os.getenv('TURN_BREAKER')
system_symbol = os.getenv('SYSTEM_SYMBOL')
user_symbol = os.getenv('USER_SYMBOL')
assistant_symbol = os.getenv('ASSISTANT_SYMBOL')

# HTTP headers for direct calls to the inference endpoint.
headers = {
    "Accept" : "application/json",
    "Authorization": f"Bearer {personal_secret_token}",
    "Content-Type": "application/json"
}
def query(payload, timeout=120):
    """POST `payload` to the inference endpoint and return the decoded JSON.

    Args:
        payload: JSON-serializable request body (e.g. {"inputs": ..., "parameters": ...}).
        timeout: Seconds to wait for the HTTP response before raising
            `requests.exceptions.Timeout`. Defaults to 120 so a dead or
            unreachable endpoint cannot hang the Gradio worker forever
            (the original call had no timeout at all).

    Returns:
        The endpoint's JSON response. HTTP errors are not raised here; an
        error status still returns the endpoint's JSON error body for the
        caller to inspect.
    """
    response = requests.post(endpoint_url, headers=headers, json=payload, timeout=timeout)
    return response.json()
def get_status():
    """Return the current status string (e.g. "running") of the configured endpoint."""
    return get_inference_endpoint(endpoint_name, token=personal_secret_token).status
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_new_tokens,
    temperature,
    top_p,
    progress=gr.Progress(),
):
    """Chat handler for gr.ChatInterface.

    Wakes the inference endpoint if it is scaled to zero, rebuilds the full
    conversation prompt from `history` using the role-symbol / turn-breaker
    format, and queries the endpoint for a completion.

    Args:
        message: The new user message.
        history: Prior (user, assistant) turn pairs from the ChatInterface.
        system_message: System/persona prompt from the UI textbox.
        max_new_tokens: Generation length cap.
        temperature: Sampling temperature; 0 selects greedy decoding.
        top_p: Nucleus-sampling cutoff.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        The raw JSON response from the endpoint.
    """
    progress(0, desc="Starting")
    # Scaled-to-zero endpoints resume on first request: send a throwaway
    # prompt, then poll until the endpoint reports "running".
    if get_status() != "running":
        query({"inputs": "wake up!"})
        progress(0.25, desc="Waking up model")
        while get_status() != "running":
            time.sleep(1)
    progress(0.5, desc="Generating")

    # Rebuild the conversation, prefixing each turn with its role symbol.
    # Fix: the system message previously went in bare even though
    # `system_symbol` is loaded from the environment for this purpose and
    # every other role gets its prefix.
    all_messages = [system_symbol + system_message]
    for user_turn, assistant_turn in history:
        if user_turn:
            all_messages.append(user_symbol + user_turn)
        if assistant_turn:
            all_messages.append(assistant_symbol + assistant_turn)
    all_messages.append(user_symbol + message)

    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        # Sampling with temperature 0 is invalid; fall back to greedy decode.
        do_sample=temperature > 0,
        top_p=top_p,
        temperature=temperature,
    )
    response = query({
        "inputs": turn_breaker.join(all_messages),
        "parameters": generation_kwargs,
    })
    progress(1, desc="Generating")
    return response
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
respond,
additional_inputs=[
gr.Textbox(value="请你扮演一个开心,积极的角色,名叫贺英旭,今年26岁,工作是程序员。你需要以这个身份和朋友们进行对话。", label="System message"),
gr.Slider(minimum=1, maximum=2048, value=64, step=1, label="Max new tokens"),
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.7,
step=0.05,
label="Top-p (nucleus sampling)",
),
],
show_progress="full"
)
if __name__ == "__main__":
demo.launch() |