Wonderplex committed on
Commit
4128c07
1 Parent(s): 4f8bd37

added format prompt and changed sotopia information accordion (#35)

Browse files
Files changed (2) hide show
  1. app.py +27 -41
  2. utils.py +29 -1
app.py CHANGED
@@ -12,7 +12,7 @@ from transformers import (
12
  BitsAndBytesConfig,
13
  )
14
 
15
- from utils import Agent, format_sotopia_prompt, get_starter_prompt
16
  from functools import cache
17
 
18
  DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
@@ -84,7 +84,7 @@ def introduction():
84
 
85
 
86
  def param_accordion(according_visible=True):
87
- with gr.Accordion("Parameters", open=False, visible=according_visible):
88
  model_name = gr.Dropdown(
89
  choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "mistralai/Mistral-7B-Instruct-v0.1", "GPT3.5"], # Example model choices
90
  value="cmu-lti/sotopia-pi-mistral-7b-BC_SR", # Default value
@@ -116,45 +116,31 @@ def param_accordion(according_visible=True):
116
  return temperature, session_id, max_tokens, model_name
117
 
118
 
119
- def sotopia_info_accordion(
120
- human_agent, machine_agent, scenario, according_visible=True
121
- ):
122
- with gr.Accordion(
123
- "Sotopia Information", open=False, visible=according_visible
124
- ):
125
  with gr.Row():
126
- with gr.Column():
127
- user_name = gr.Textbox(
128
- lines=1,
129
- label="username",
130
- value=human_agent.name,
131
- interactive=True,
132
- placeholder=f"{human_agent.name}: ",
133
- show_label=False,
134
- max_lines=1,
135
- )
136
- with gr.Column():
137
- bot_name = gr.Textbox(
138
- lines=1,
139
- value=machine_agent.name,
140
- interactive=True,
141
- placeholder=f"{machine_agent.name}: ",
142
- show_label=False,
143
- max_lines=1,
144
- visible=False,
145
- )
146
- with gr.Column():
147
- scenario = gr.Textbox(
148
- lines=4,
149
- value=scenario,
150
- interactive=False,
151
- placeholder="Scenario",
152
- show_label=False,
153
- max_lines=4,
154
- visible=False,
155
- )
156
- return user_name, bot_name, scenario
157
-
158
 
159
  def instructions_accordion(instructions, according_visible=False):
160
  with gr.Accordion("Instructions", open=False, visible=according_visible):
@@ -206,7 +192,7 @@ def chat_tab():
206
  text_output = tokenizer.decode(
207
  output_tokens[0], skip_special_tokens=True
208
  )
209
- return text_output
210
 
211
  with gr.Column():
212
  with gr.Row():
 
12
  BitsAndBytesConfig,
13
  )
14
 
15
+ from utils import Agent, format_sotopia_prompt, get_starter_prompt, format_bot_message
16
  from functools import cache
17
 
18
  DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
 
84
 
85
 
86
  def param_accordion(according_visible=True):
87
+ with gr.Accordion("Parameters", open=True, visible=according_visible):
88
  model_name = gr.Dropdown(
89
  choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "mistralai/Mistral-7B-Instruct-v0.1", "GPT3.5"], # Example model choices
90
  value="cmu-lti/sotopia-pi-mistral-7b-BC_SR", # Default value
 
116
  return temperature, session_id, max_tokens, model_name
117
 
118
 
119
def sotopia_info_accordion(human_agent, machine_agent, scenario, accordion_visible=True):
    """Render the "Sotopia Information" accordion.

    Builds three editable textboxes in one row — the human agent's name,
    the machine agent's name, and the scenario description — pre-filled
    from the given agents and scenario text.

    Returns the three gr.Textbox components, in that order.
    """
    with gr.Accordion("Sotopia Information", open=accordion_visible):
        with gr.Row():
            # Keyword arguments for each textbox, in creation order.
            textbox_specs = [
                dict(
                    lines=1,
                    label="Human Agent Name",
                    value=human_agent.name,
                    interactive=True,
                    placeholder="Enter human agent name",
                ),
                dict(
                    lines=1,
                    label="Machine Agent Name",
                    value=machine_agent.name,
                    interactive=True,
                    placeholder="Enter machine agent name",
                ),
                dict(
                    lines=4,
                    label="Scenario Description",
                    value=scenario,
                    interactive=True,
                    placeholder="Enter scenario description",
                ),
            ]
            human_name_box, machine_name_box, scenario_box = (
                gr.Textbox(**spec) for spec in textbox_specs
            )
    return human_name_box, machine_name_box, scenario_box
 
 
 
 
 
 
 
 
 
 
144
 
145
  def instructions_accordion(instructions, according_visible=False):
146
  with gr.Accordion("Instructions", open=False, visible=according_visible):
 
192
  text_output = tokenizer.decode(
193
  output_tokens[0], skip_special_tokens=True
194
  )
195
+ return format_bot_message(text_output)
196
 
197
  with gr.Column():
198
  with gr.Row():
utils.py CHANGED
@@ -1,4 +1,16 @@
1
  from typing import List, Tuple
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  class Agent:
4
  def __init__(self, name, background, goal, secrets, personality):
@@ -36,6 +48,20 @@ def truncate_dialogue_history_to_length(dia_his, surpass_num, tokenizer):
36
  return trunc_dia
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def dialogue_history_creation(history, user_name, bot_name):
40
  dialogue_history = ""
41
  for idx, turn in enumerate(history):
@@ -43,6 +69,8 @@ def dialogue_history_creation(history, user_name, bot_name):
43
  # TODOTODO (haofeiyu): we first assume that human talks first
44
  user_turn_idx = idx * 2
45
  bot_turn_idx = idx * 2 + 1
 
 
46
  dialogue_history = f"{dialogue_history}\n\nTurn #{user_turn_idx}: {user_name}: {user_message}\n\nTurn #{bot_turn_idx}: {bot_name}: {bot_message}"
47
  last_turn_idx = len(history) * 2
48
  return dialogue_history, last_turn_idx
@@ -75,4 +103,4 @@ def format_sotopia_prompt(
75
  )
76
  prompt = f"{prompt}\n{dialogue_history}"
77
  prompt = f"{prompt}\n\nTurn #{last_turn_idx+1}: {user_name}: {message}\n.\nYou are at Turn #{last_turn_idx+2}."
78
- return prompt
 
1
  from typing import List, Tuple
2
+ import ast
3
+
4
+ FORMAT_TEMPLATE = """ Your available action types are
5
+ "none action speak non-verbal communication leave".
6
+ Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.
7
+
8
+ Please only generate a JSON string including the action type and the argument.
9
+ Your action should follow the given format:
10
+ \nAs an example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}, \"required\": [\"foo\"]}
11
+ the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance of the schema. The object {\"properties\": {\"foo\": [\"bar\", \"baz\"]}} is not well-formatted.
12
+ \nHere is the output schema:\n```\n{\"description\": \"An interface for messages.\\nThere is only one required method: to_natural_language\", \"properties\": {\"action_type\": {\"title\": \"Action Type\", \"description\": \"whether to speak at this turn or choose to not do anything\", \"enum\": [\"none\", \"speak\", \"non-verbal communication\", \"action\", \"leave\"], \"type\": \"string\"}, \"argument\": {\"title\": \"Argument\", \"description\": \"the utterance if choose to speak, the expression or gesture if choose non-verbal communication, or the physical action if choose action\", \"type\": \"string\"}}, \"required\": [\"action_type\", \"argument\"]}\n```\u001b[0m
13
+ """
14
 
15
  class Agent:
16
  def __init__(self, name, background, goal, secrets, personality):
 
48
  return trunc_dia
49
 
50
 
51
+ def format_bot_message(bot_message) -> str:
52
+ json_response = ast.literal_eval(bot_message)
53
+ match json_response["action_type"]:
54
+ case "none":
55
+ return 'did nothing'
56
+ case "speak":
57
+ return json_response["argument"]
58
+ case "non-verbal communication":
59
+ return f'[{json_response["action_type"]}] {json_response["argument"]}'
60
+ case "action":
61
+ return f'[{json_response["action_type"]}] {json_response["argument"]}'
62
+ case "leave":
63
+ return 'left the conversation'
64
+
65
  def dialogue_history_creation(history, user_name, bot_name):
66
  dialogue_history = ""
67
  for idx, turn in enumerate(history):
 
69
  # TODOTODO (haofeiyu): we first assume that human talks first
70
  user_turn_idx = idx * 2
71
  bot_turn_idx = idx * 2 + 1
72
+ if not bot_message.startswith("["): # if action type == speak, need to add 'said: ' to be consistent with the dialog prompt
73
+ bot_message = "said :" + bot_message
74
  dialogue_history = f"{dialogue_history}\n\nTurn #{user_turn_idx}: {user_name}: {user_message}\n\nTurn #{bot_turn_idx}: {bot_name}: {bot_message}"
75
  last_turn_idx = len(history) * 2
76
  return dialogue_history, last_turn_idx
 
103
  )
104
  prompt = f"{prompt}\n{dialogue_history}"
105
  prompt = f"{prompt}\n\nTurn #{last_turn_idx+1}: {user_name}: {message}\n.\nYou are at Turn #{last_turn_idx+2}."
106
+ return prompt + FORMAT_TEMPLATE if use_format_guide else prompt