ysharma (HF staff) committed on
Commit
8f4c543
1 Parent(s): 36b3cb9

added batch chatbot in new tab

Browse files
Files changed (1) hide show
  1. app.py +43 -2
app.py CHANGED
@@ -4,7 +4,8 @@ import os
4
  import requests
5
 
6
  hf_token = os.getenv('HF_TOKEN')
7
- api_url = os.getenv('API_URL')
 
8
  headers = {
9
  'Content-Type': 'application/json',
10
  }
@@ -82,4 +83,44 @@ def predict(message, chatbot):
82
  gr.Warning(f"KeyError: {e} occurred for JSON object: {json_obj}")
83
  continue
84
 
85
- gr.ChatInterface(predict, title=title, description=description, css=css, examples=examples, cache_examples=True).queue(concurrency_count=75).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import requests
5
 
# Secrets and endpoint URLs come from the Space's environment configuration.
hf_token = os.getenv('HF_TOKEN')
api_url = os.getenv('API_URL')                    # streaming endpoint
api_url_nostream = os.getenv('API_URL_NOSTREAM')  # non-streaming (batch) endpoint

# Every request to the inference endpoints carries a JSON payload.
headers = {'Content-Type': 'application/json'}
 
83
  gr.Warning(f"KeyError: {e} occurred for JSON object: {json_obj}")
84
  continue
85
 
86
def predict_batch(message, chatbot):
    """Non-streaming chat handler for the "Batch" tab.

    Builds a Llama-2-style prompt from the conversation history, sends one
    non-streaming generation request to the inference endpoint, and returns
    the model's reply as a single string.

    Args:
        message: The user's latest message.
        chatbot: Gradio chat history — a list of [user, assistant] pairs.

    Returns:
        The generated text, or a human-readable error message so the chat
        UI always has something to display (never returns None).
    """
    # Llama-2 chat format: system block first, then alternating turns.
    input_prompt = f"[INST]<<SYS>>\n{system_message}\n<</SYS>>\n\n "
    for interaction in chatbot:
        input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s> [INST] "

    input_prompt = input_prompt + str(message) + " [/INST] "

    data = {
        "inputs": input_prompt,
        "parameters": {"max_new_tokens": 256}
    }

    response = requests.post(api_url_nostream, headers=headers, data=json.dumps(data), auth=('hf', hf_token))

    # Guard clauses: every failure path returns a visible message instead of
    # silently returning None (which the chat UI cannot render).
    if response.status_code != 200:
        print(f"Request failed with status code {response.status_code}")
        return f"Request failed with status code {response.status_code}. Please try again."

    try:
        json_obj = response.json()
    except json.JSONDecodeError:
        print(f"Failed to decode response as JSON: {response.text}")
        return "Failed to decode the model response as JSON. Please try again."

    # NOTE(review): assumes the endpoint returns a JSON object with either a
    # top-level 'generated_text' or 'error' key — confirm against the API.
    if 'generated_text' in json_obj and len(json_obj['generated_text']) > 0:
        return json_obj['generated_text']
    if 'error' in json_obj:
        return json_obj['error'] + ' Please refresh and try again with smaller input prompt'

    print(f"Unexpected response: {json_obj}")
    return "Received an unexpected response from the model. Please try again."
115
# Gradio Demo: one tab per inference mode, sharing a single request queue.
with gr.Blocks() as demo:

    # Build an identical ChatInterface for each (tab label, handler) pair;
    # "Streaming" uses the SSE handler, "Batch" the non-streaming one.
    for tab_label, chat_fn in (("Streaming", predict), ("Batch", predict_batch)):
        with gr.Tab(tab_label):
            gr.ChatInterface(chat_fn, title=title, description=description,
                             css=css, examples=examples, cache_examples=True)

demo.queue(concurrency_count=75).launch(debug=True)