yuntian-deng commited on
Commit
34ab564
1 Parent(s): 56d3094

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -14
app.py CHANGED
@@ -17,6 +17,20 @@ def exception_handler(exception_type, exception, traceback):
17
  sys.excepthook = exception_handler
18
  sys.tracebacklimit = 0
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def predict(inputs, top_p, temperature, chat_counter, chatbot=[], history=[]):
21
 
22
  payload = {
@@ -37,24 +51,24 @@ def predict(inputs, top_p, temperature, chat_counter, chatbot=[], history=[]):
37
 
38
  # print(f"chat_counter - {chat_counter}")
39
  if chat_counter != 0 :
40
- messages=[]
41
- for data in chatbot:
42
- temp1 = {}
43
- temp1["role"] = "user"
44
- temp1["content"] = data[0]
45
- temp2 = {}
46
- temp2["role"] = "assistant"
47
- temp2["content"] = data[1]
48
- messages.append(temp1)
49
- messages.append(temp2)
 
50
  temp3 = {}
51
  temp3["role"] = "user"
52
  temp3["content"] = inputs
53
  messages.append(temp3)
54
- #messages
55
  payload = {
56
  "model": "gpt-4",
57
- "messages": messages, #[{"role": "user", "content": f"{inputs}"}],
58
  "temperature" : temperature, #1.0,
59
  "top_p": top_p, #1.0,
60
  "n" : 1,
@@ -66,7 +80,6 @@ def predict(inputs, top_p, temperature, chat_counter, chatbot=[], history=[]):
66
  chat_counter+=1
67
 
68
  history.append(inputs)
69
- # print(f"payload is - {payload}")
70
  # make a POST request to the API endpoint using the requests.post method, passing in stream=True
71
  response = requests.post(API_URL, headers=headers, json=payload, stream=True)
72
  response_code = f"{response}"
@@ -94,7 +107,7 @@ def predict(inputs, top_p, temperature, chat_counter, chatbot=[], history=[]):
94
  history.append(" " + partial_words)
95
  else:
96
  history[-1] = partial_words
97
- chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2) ] # convert to tuples of list
98
  token_counter+=1
99
  yield chat, history, chat_counter, response # resembles {chatbot: chat, state: history}
100
  print(json.dumps({"chat_counter": chat_counter, "payload": payload, "partial_words": partial_words, "token_counter": token_counter, "counter": counter}))
 
17
  sys.excepthook = exception_handler
18
  sys.tracebacklimit = 0
19
 
20
+ #https://github.com/gradio-app/gradio/issues/3531#issuecomment-1484029099
21
+ def parse_codeblock(text):
22
+ lines = text.split("\n")
23
+ for i, line in enumerate(lines):
24
+ if "```" in line:
25
+ if line != "```":
26
+ lines[i] = f'<pre><code class="{lines[i][3:]}">'
27
+ else:
28
+ lines[i] = '</code></pre>'
29
+ else:
30
+ if i > 0:
31
+ lines[i] = "<br/>" + line.replace("<", "&lt;").replace(">", "&gt;")
32
+ return "".join(lines)
33
+
34
  def predict(inputs, top_p, temperature, chat_counter, chatbot=[], history=[]):
35
 
36
  payload = {
 
51
 
52
  # print(f"chat_counter - {chat_counter}")
53
  if chat_counter != 0 :
54
+ messages = []
55
+ for i, data in enumerate(history):
56
+ if i % 2 == 0:
57
+ role = 'user'
58
+ else:
59
+ role = 'assistant'
60
+ temp = {}
61
+ temp["role"] = role
62
+ temp["content"] = data
63
+ messages.append(temp)
64
+
65
  temp3 = {}
66
  temp3["role"] = "user"
67
  temp3["content"] = inputs
68
  messages.append(temp3)
 
69
  payload = {
70
  "model": "gpt-4",
71
+ "messages": messages,
72
  "temperature" : temperature, #1.0,
73
  "top_p": top_p, #1.0,
74
  "n" : 1,
 
80
  chat_counter+=1
81
 
82
  history.append(inputs)
 
83
  # make a POST request to the API endpoint using the requests.post method, passing in stream=True
84
  response = requests.post(API_URL, headers=headers, json=payload, stream=True)
85
  response_code = f"{response}"
 
107
  history.append(" " + partial_words)
108
  else:
109
  history[-1] = partial_words
110
+ chat = [(parse_codeblock(history[i]), parse_codeblock(history[i + 1])) for i in range(0, len(history) - 1, 2) ] # convert to tuples of list
111
  token_counter+=1
112
  yield chat, history, chat_counter, response # resembles {chatbot: chat, state: history}
113
  print(json.dumps({"chat_counter": chat_counter, "payload": payload, "partial_words": partial_words, "token_counter": token_counter, "counter": counter}))