davidberenstein1957 commited on
Commit
104bc54
·
1 Parent(s): 8dff124

add retry edit and undo actions

Browse files
Files changed (1) hide show
  1. app/app.py +50 -7
app/app.py CHANGED
@@ -6,11 +6,10 @@ from mimetypes import guess_type
6
  from pathlib import Path
7
 
8
  import gradio as gr
 
9
  from huggingface_hub import InferenceClient
10
  from pandas import DataFrame
11
 
12
- from feedback import save_feedback
13
-
14
  client = InferenceClient(
15
  token=os.getenv("HF_TOKEN"),
16
  model=(
@@ -30,7 +29,7 @@ def add_user_message(history, message):
30
  return history, gr.MultimodalTextbox(value=None, interactive=False)
31
 
32
 
33
- def _format_history_as_messages(history: list):
34
  messages = []
35
  current_role = None
36
  current_message_content = []
@@ -84,17 +83,25 @@ def _process_content(content) -> str | list[str]:
84
  return content
85
 
86
 
 
 
 
 
 
 
 
 
 
 
87
  def respond_system_message(history: list) -> list: # -> list:
88
  """Respond to the user message with a system message"""
89
- messages = _format_history_as_messages(history)
90
  response = client.chat.completions.create(
91
  messages=messages,
92
  max_tokens=2000,
93
  stream=False,
94
  )
95
  content = response.choices[0].message.content
96
- # TODO: Add a response to the user message
97
-
98
  message = gr.ChatMessage(role="assistant", content=content)
99
  history.append(message)
100
  return history
@@ -103,7 +110,10 @@ def respond_system_message(history: list) -> list: # -> list:
103
  def wrangle_like_data(x: gr.LikeData, history) -> DataFrame:
104
  """Wrangle conversations and liked data into a DataFrame"""
105
 
106
- liked_index = x.index[0]
 
 
 
107
 
108
  output_data = []
109
  for idx, message in enumerate(history):
@@ -124,6 +134,19 @@ def wrangle_like_data(x: gr.LikeData, history) -> DataFrame:
124
  return history, DataFrame(data=output_data)
125
 
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  def submit_conversation(dataframe, session_id):
128
  """ "Submit the conversation to dataset repo"""
129
  if dataframe.empty:
@@ -154,8 +177,10 @@ with gr.Blocks() as demo:
154
 
155
  chatbot = gr.Chatbot(
156
  elem_id="chatbot",
 
157
  bubble_full_width=False,
158
  type="messages",
 
159
  )
160
 
161
  chat_input = gr.MultimodalTextbox(
@@ -189,6 +214,24 @@ with gr.Blocks() as demo:
189
  like_user_message=False,
190
  )
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  gr.Button(
193
  value="Submit conversation",
194
  ).click(
 
6
  from pathlib import Path
7
 
8
  import gradio as gr
9
+ from feedback import save_feedback
10
  from huggingface_hub import InferenceClient
11
  from pandas import DataFrame
12
 
 
 
13
  client = InferenceClient(
14
  token=os.getenv("HF_TOKEN"),
15
  model=(
 
29
  return history, gr.MultimodalTextbox(value=None, interactive=False)
30
 
31
 
32
+ def format_history_as_messages(history: list):
33
  messages = []
34
  current_role = None
35
  current_message_content = []
 
83
  return content
84
 
85
 
86
+ def remove_last_message(history: list) -> list:
87
+ return history[:-1]
88
+
89
+
90
+ def retry_respond_system_message(history: list) -> list:
91
+ """Respond to the user message with a system message"""
92
+ history = remove_last_message(history)
93
+ return respond_system_message(history)
94
+
95
+
96
  def respond_system_message(history: list) -> list: # -> list:
97
  """Respond to the user message with a system message"""
98
+ messages = format_history_as_messages(history)
99
  response = client.chat.completions.create(
100
  messages=messages,
101
  max_tokens=2000,
102
  stream=False,
103
  )
104
  content = response.choices[0].message.content
 
 
105
  message = gr.ChatMessage(role="assistant", content=content)
106
  history.append(message)
107
  return history
 
110
  def wrangle_like_data(x: gr.LikeData, history) -> DataFrame:
111
  """Wrangle conversations and liked data into a DataFrame"""
112
 
113
+ if isinstance(x.index, int):
114
+ liked_index = x.index
115
+ else:
116
+ liked_index = x.index[0]
117
 
118
  output_data = []
119
  for idx, message in enumerate(history):
 
134
  return history, DataFrame(data=output_data)
135
 
136
 
137
+ def wrangle_edit_data(x: gr.EditData, history: list) -> list:
138
+ if isinstance(x.index, int):
139
+ index = x.index
140
+ else:
141
+ index = x.index[0]
142
+
143
+ if history[index]["role"] == "user":
144
+ history = history[:index]
145
+ return respond_system_message(history)
146
+ else:
147
+ return history
148
+
149
+
150
  def submit_conversation(dataframe, session_id):
151
  """ "Submit the conversation to dataset repo"""
152
  if dataframe.empty:
 
177
 
178
  chatbot = gr.Chatbot(
179
  elem_id="chatbot",
180
+ editable="all",
181
  bubble_full_width=False,
182
  type="messages",
183
+ feedback_options=["Like", "Dislike"],
184
  )
185
 
186
  chat_input = gr.MultimodalTextbox(
 
214
  like_user_message=False,
215
  )
216
 
217
+ chatbot.retry(
218
+ fn=retry_respond_system_message,
219
+ inputs=[chatbot],
220
+ outputs=[chatbot],
221
+ )
222
+
223
+ chatbot.edit(
224
+ fn=wrangle_edit_data,
225
+ inputs=[chatbot],
226
+ outputs=[chatbot],
227
+ )
228
+
229
+ chatbot.undo(
230
+ fn=remove_last_message,
231
+ inputs=[chatbot],
232
+ outputs=[chatbot],
233
+ )
234
+
235
  gr.Button(
236
  value="Submit conversation",
237
  ).click(