Yingxu He committed on
Commit
5ee0c02
1 Parent(s): cdaf751

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -13
app.py CHANGED
@@ -1,8 +1,7 @@
1
  import os
2
- import urllib
3
  import gradio as gr
4
-
5
- import requests
6
 
7
  endpoint_url = os.getenv('ENDPOINT_URL')
8
  personal_secret_token = os.getenv('PERSONAL_HF_TOKEN')
@@ -12,15 +11,7 @@ system_symbol = os.getenv('SYSTEM_SYMBOL')
12
  user_symbol = os.getenv('USER_SYMBOL')
13
  assistant_symbol = os.getenv('ASSISTANT_SYMBOL')
14
 
15
- headers = {
16
- "Accept" : "application/json",
17
- "Authorization": f"Bearer {personal_secret_token}",
18
- "Content-Type": "application/json"
19
- }
20
-
21
- def query(payload):
22
- response = requests.post(endpoint_url, headers=headers, json=payload)
23
- return response.json()
24
 
25
  def respond(
26
  message,
@@ -29,7 +20,16 @@ def respond(
29
  max_new_tokens,
30
  temperature,
31
  top_p,
 
32
  ):
 
 
 
 
 
 
 
 
33
  all_messages = [system_message]
34
 
35
  for val in history:
@@ -48,11 +48,13 @@ def respond(
48
  # stream=True,
49
  )
50
 
51
- response = query({
52
  "inputs": turn_breaker.join(all_messages),
53
  "parameters": generation_kwargs
54
  })
55
 
 
 
56
  return response
57
 
58
 
@@ -73,6 +75,7 @@ demo = gr.ChatInterface(
73
  label="Top-p (nucleus sampling)",
74
  ),
75
  ],
 
76
  )
77
 
78
 
 
1
  import os
2
+ import time
3
  import gradio as gr
4
+ from huggingface_hub import get_inference_endpoint
 
5
 
6
  endpoint_url = os.getenv('ENDPOINT_URL')
7
  personal_secret_token = os.getenv('PERSONAL_HF_TOKEN')
 
11
  user_symbol = os.getenv('USER_SYMBOL')
12
  assistant_symbol = os.getenv('ASSISTANT_SYMBOL')
13
 
14
+ endpoint = get_inference_endpoint(endpoint_url, token=personal_secret_token)
 
 
 
 
 
 
 
 
15
 
16
  def respond(
17
  message,
 
20
  max_new_tokens,
21
  temperature,
22
  top_p,
23
+ progress=gr.Progress()
24
  ):
25
+ progress(0, desc="Starting")
26
+
27
+ while endpoint.status != "running":
28
+ progress(0.25, desc="Waking up model")
29
+ time.sleep(1)
30
+
31
+ progress(0.5, desc="Generating")
32
+
33
  all_messages = [system_message]
34
 
35
  for val in history:
 
48
  # stream=True,
49
  )
50
 
51
+ response = endpoint.client.post({
52
  "inputs": turn_breaker.join(all_messages),
53
  "parameters": generation_kwargs
54
  })
55
 
56
+ progress(1, desc="Generating")
57
+
58
  return response
59
 
60
 
 
75
  label="Top-p (nucleus sampling)",
76
  ),
77
  ],
78
+ show_progress="full"
79
  )
80
 
81