davidberenstein1957 commited on
Commit
8880b78
·
1 Parent(s): 9109509

add option to chosen and rejected

Browse files
Files changed (1) hide show
  1. app/app.py +84 -26
app/app.py CHANGED
@@ -9,6 +9,7 @@ from typing import Optional
9
 
10
  import gradio as gr
11
  from feedback import save_feedback
 
12
  from huggingface_hub import InferenceClient
13
  from pandas import DataFrame
14
 
@@ -85,14 +86,14 @@ def _process_content(content) -> str | list[str]:
85
  return content
86
 
87
 
88
- def add_fake_like_data(history: list, session_id: str) -> None:
89
- fake_like_data = {
90
  "index": len(history) - 1,
91
  "value": history[-1],
92
- "liked": False,
93
  }
94
  _, dataframe = wrangle_like_data(
95
- gr.LikeData(target=None, data=fake_like_data), history.copy()
96
  )
97
  submit_conversation(dataframe, session_id)
98
 
@@ -119,7 +120,15 @@ def respond_system_message(
119
 
120
  def update_dataframe(dataframe: DataFrame, history: list) -> DataFrame:
121
  """Update the dataframe with the new message"""
122
- return wrangle_like_data(gr.LikeData(target=None, data=history[-1]), history)
 
 
 
 
 
 
 
 
123
 
124
 
125
  def wrangle_like_data(x: gr.LikeData, history) -> DataFrame:
@@ -132,8 +141,13 @@ def wrangle_like_data(x: gr.LikeData, history) -> DataFrame:
132
 
133
  output_data = []
134
  for idx, message in enumerate(history):
 
 
 
135
  if idx == liked_index:
136
  message["metadata"] = {"title": "liked" if x.liked else "disliked"}
 
 
137
  rating = message["metadata"].get("title")
138
  if rating == "liked":
139
  message["rating"] = 1
@@ -142,6 +156,19 @@ def wrangle_like_data(x: gr.LikeData, history) -> DataFrame:
142
  else:
143
  message["rating"] = 0
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  output_data.append(
146
  dict(
147
  [(k, v) for k, v in message.items() if k not in ["metadata", "options"]]
@@ -151,7 +178,9 @@ def wrangle_like_data(x: gr.LikeData, history) -> DataFrame:
151
  return history, DataFrame(data=output_data)
152
 
153
 
154
- def wrangle_edit_data(x: gr.EditData, history: list, session_id: str) -> list:
 
 
155
  """Edit the conversation and add negative feedback if assistant message is edited, otherwise regenerate the message
156
 
157
  Return the history with the new message"""
@@ -160,40 +189,62 @@ def wrangle_edit_data(x: gr.EditData, history: list, session_id: str) -> list:
160
  else:
161
  index = x.index[0]
162
 
 
 
 
 
 
 
163
  if history[index]["role"] == "user":
164
- add_fake_like_data(history[: index + 2], session_id)
165
- return respond_system_message(
 
 
166
  history[: index + 1],
167
  temperature=random.randint(1, 100) / 100,
168
  seed=random.randint(0, 1000000),
169
  )
 
170
  else:
171
- # Add negative feedback on the original message
172
- add_fake_like_data(history[: index + 1], session_id)
173
- return history[: index + 1]
174
-
175
-
176
- def wrangle_undo_data(x: gr.UndoData, history: list, session_id: str) -> list:
 
 
 
 
 
 
 
 
 
177
  """Undo the last turn and add negative feedback on the original message
178
 
179
  Return the history without the last turn"""
180
  add_fake_like_data(history, session_id)
181
  # Return the history without the last turn
182
- return history[:-2]
 
183
 
184
 
185
- def wrangle_retry_data(x: gr.RetryData, history: list, session_id: str) -> list:
 
 
186
  """Respond to the user message with a system message and add negative feedback on the original message
187
 
188
  Return the history with the new message"""
189
  add_fake_like_data(history, session_id)
190
 
191
  # Return the history without a new message
192
- return respond_system_message(
193
  history[:-1],
194
  temperature=random.randint(1, 100) / 100,
195
  seed=random.randint(0, 1000000),
196
  )
 
197
 
198
 
199
  def submit_conversation(dataframe, session_id):
@@ -214,7 +265,14 @@ def submit_conversation(dataframe, session_id):
214
  return (gr.Dataframe(value=None, interactive=False), [])
215
 
216
 
217
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
218
  ##############################
219
  # Chatbot
220
  ##############################
@@ -240,7 +298,7 @@ with gr.Blocks() as demo:
240
  submit_btn=True,
241
  )
242
 
243
- dataframe = gr.DataFrame()
244
 
245
  submit_btn = gr.Button(
246
  value="Submit conversation",
@@ -256,7 +314,7 @@ with gr.Blocks() as demo:
256
  outputs=[chatbot, chat_input],
257
  ).then(respond_system_message, chatbot, chatbot, api_name="bot_response").then(
258
  lambda: gr.Textbox(interactive=True), None, [chat_input]
259
- )
260
 
261
  chatbot.like(
262
  fn=wrangle_like_data,
@@ -267,20 +325,20 @@ with gr.Blocks() as demo:
267
 
268
  chatbot.retry(
269
  fn=wrangle_retry_data,
270
- inputs=[chatbot, session_id],
271
- outputs=[chatbot],
272
  )
273
 
274
  chatbot.edit(
275
  fn=wrangle_edit_data,
276
- inputs=[chatbot, session_id],
277
  outputs=[chatbot],
278
- )
279
 
280
  chatbot.undo(
281
  fn=wrangle_undo_data,
282
- inputs=[chatbot, session_id],
283
- outputs=[chatbot],
284
  )
285
 
286
  submit_btn.click(
 
9
 
10
  import gradio as gr
11
  from feedback import save_feedback
12
+ from gradio.components.chatbot import Option
13
  from huggingface_hub import InferenceClient
14
  from pandas import DataFrame
15
 
 
86
  return content
87
 
88
 
89
+ def add_fake_like_data(history: list, session_id: str, liked: bool = False) -> None:
90
+ data = {
91
  "index": len(history) - 1,
92
  "value": history[-1],
93
+ "liked": True,
94
  }
95
  _, dataframe = wrangle_like_data(
96
+ gr.LikeData(target=None, data=data), history.copy()
97
  )
98
  submit_conversation(dataframe, session_id)
99
 
 
120
 
121
  def update_dataframe(dataframe: DataFrame, history: list) -> DataFrame:
122
  """Update the dataframe with the new message"""
123
+ data = {
124
+ "index": 9999,
125
+ "value": None,
126
+ "liked": False,
127
+ }
128
+ _, dataframe = wrangle_like_data(
129
+ gr.LikeData(target=None, data=data), history.copy()
130
+ )
131
+ return dataframe
132
 
133
 
134
  def wrangle_like_data(x: gr.LikeData, history) -> DataFrame:
 
141
 
142
  output_data = []
143
  for idx, message in enumerate(history):
144
+ print(message)
145
+ if isinstance(message, gr.ChatMessage):
146
+ message = message.__dict__
147
  if idx == liked_index:
148
  message["metadata"] = {"title": "liked" if x.liked else "disliked"}
149
+ if not isinstance(message["metadata"], dict):
150
+ message["metadata"] = message["metadata"].__dict__
151
  rating = message["metadata"].get("title")
152
  if rating == "liked":
153
  message["rating"] = 1
 
156
  else:
157
  message["rating"] = 0
158
 
159
+ message["chosen"] = ""
160
+ message["rejected"] = ""
161
+ if message["options"]:
162
+ for option in message["options"]:
163
+ if not isinstance(option, dict):
164
+ option = option.__dict__
165
+ message[option["label"]] = option["value"]
166
+ else:
167
+ if message["rating"] == 1:
168
+ message["chosen"] = message["content"]
169
+ elif message["rating"] == -1:
170
+ message["rejected"] = message["content"]
171
+
172
  output_data.append(
173
  dict(
174
  [(k, v) for k, v in message.items() if k not in ["metadata", "options"]]
 
178
  return history, DataFrame(data=output_data)
179
 
180
 
181
+ def wrangle_edit_data(
182
+ x: gr.EditData, history: list, dataframe: DataFrame, session_id: str
183
+ ) -> list:
184
  """Edit the conversation and add negative feedback if assistant message is edited, otherwise regenerate the message
185
 
186
  Return the history with the new message"""
 
189
  else:
190
  index = x.index[0]
191
 
192
+ original_message = gr.ChatMessage(
193
+ role="assistant", content=dataframe.iloc[index]["content"]
194
+ ).__dict__
195
+ import pdb
196
+
197
+ pdb.set_trace()
198
  if history[index]["role"] == "user":
199
+ # Add feedback on original and corrected message
200
+ add_fake_like_data(history[: index + 2], session_id, liked=True)
201
+ add_fake_like_data(history[: index + 1] + [original_message], session_id)
202
+ history = respond_system_message(
203
  history[: index + 1],
204
  temperature=random.randint(1, 100) / 100,
205
  seed=random.randint(0, 1000000),
206
  )
207
+ return history
208
  else:
209
+ # Add feedback on original and corrected message
210
+ add_fake_like_data(history[: index + 1], session_id, liked=True)
211
+ add_fake_like_data(history[:index] + [original_message], session_id)
212
+ history = history[: index + 1]
213
+ # add chosen and rejected options
214
+ history[-1]["options"] = [
215
+ Option(label="chosen", value=x.value),
216
+ Option(label="rejected", value=original_message["content"]),
217
+ ]
218
+ return history
219
+
220
+
221
+ def wrangle_undo_data(
222
+ x: gr.UndoData, history: list, dataframe: DataFrame, session_id: str
223
+ ) -> list:
224
  """Undo the last turn and add negative feedback on the original message
225
 
226
  Return the history without the last turn"""
227
  add_fake_like_data(history, session_id)
228
  # Return the history without the last turn
229
+ history = history[:-2]
230
+ return history, update_dataframe(dataframe, history)
231
 
232
 
233
+ def wrangle_retry_data(
234
+ x: gr.RetryData, history: list, dataframe: DataFrame, session_id: str
235
+ ) -> list:
236
  """Respond to the user message with a system message and add negative feedback on the original message
237
 
238
  Return the history with the new message"""
239
  add_fake_like_data(history, session_id)
240
 
241
  # Return the history without a new message
242
+ history = respond_system_message(
243
  history[:-1],
244
  temperature=random.randint(1, 100) / 100,
245
  seed=random.randint(0, 1000000),
246
  )
247
+ return history, update_dataframe(dataframe, history)
248
 
249
 
250
  def submit_conversation(dataframe, session_id):
 
265
  return (gr.Dataframe(value=None, interactive=False), [])
266
 
267
 
268
+ css = """
269
+ .options {
270
+ display: none !important;
271
+ }
272
+ """
273
+
274
+
275
+ with gr.Blocks(css=css) as demo:
276
  ##############################
277
  # Chatbot
278
  ##############################
 
298
  submit_btn=True,
299
  )
300
 
301
+ dataframe = gr.Dataframe(wrap=True)
302
 
303
  submit_btn = gr.Button(
304
  value="Submit conversation",
 
314
  outputs=[chatbot, chat_input],
315
  ).then(respond_system_message, chatbot, chatbot, api_name="bot_response").then(
316
  lambda: gr.Textbox(interactive=True), None, [chat_input]
317
+ ).then(update_dataframe, inputs=[dataframe, chatbot], outputs=[dataframe])
318
 
319
  chatbot.like(
320
  fn=wrangle_like_data,
 
325
 
326
  chatbot.retry(
327
  fn=wrangle_retry_data,
328
+ inputs=[chatbot, dataframe, session_id],
329
+ outputs=[chatbot, dataframe],
330
  )
331
 
332
  chatbot.edit(
333
  fn=wrangle_edit_data,
334
+ inputs=[chatbot, dataframe, session_id],
335
  outputs=[chatbot],
336
+ ).then(update_dataframe, inputs=[dataframe, chatbot], outputs=[dataframe])
337
 
338
  chatbot.undo(
339
  fn=wrangle_undo_data,
340
+ inputs=[chatbot, dataframe, session_id],
341
+ outputs=[chatbot, dataframe],
342
  )
343
 
344
  submit_btn.click(