Wonderplex committed
Commit c423c55
1 Parent(s): 0c22348

adopted sotopia logic (#44)

Files changed (4)
  1. .gitignore +3 -1
  2. app.py +20 -76
  3. sotopia_pi_generate.py +219 -0
  4. utils.py +20 -6
.gitignore CHANGED
@@ -1,2 +1,4 @@
 __pycache__/
-.cache/
+.cache/
+openai_api.key
+core
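
The two new ignore entries pair with the app change below: openai_api.key is the plain-text key file that app.py now reads at startup (so it must never be committed), and core keeps crash core dumps out of the repo. A minimal sketch of provisioning the key file locally (the sk-... value is a placeholder, not a real key):

# create the key file the app expects; never commit this file
with open("openai_api.key", "w") as f:
    f.write("sk-...")  # placeholder value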
app.py CHANGED
@@ -1,25 +1,19 @@
 import os
 from collections import defaultdict
-from dataclasses import dataclass
-from uuid import uuid4
 import json
 
 import gradio as gr
-import torch
-import transformers
-from peft import PeftConfig, PeftModel, get_peft_model
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    BitsAndBytesConfig,
-)
 
-from utils import Environment, Agent, format_sotopia_prompt, get_starter_prompt, format_bot_message
+from utils import Environment, Agent, get_context_prompt, dialogue_history_prompt
 from functools import cache
+from sotopia_pi_generate import prepare_model, generate_action
+
+with open("openai_api.key", "r") as f:
+    os.environ["OPENAI_API_KEY"] = f.read().strip()
 
 DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
 DEFAULT_MODEL_SELECTION = "cmu-lti/sotopia-pi-mistral-7b-BC_SR"  # "mistralai/Mistral-7B-Instruct-v0.1"
-TEMPERATURE = 0.0
+TEMPERATURE = 0.7
 TOP_P = 1
 MAX_TOKENS = 1024
 
@@ -27,6 +21,7 @@ ENVIRONMENT_PROFILES = "profiles/environment_profiles.jsonl"
 AGENT_PROFILES = "profiles/agent_profiles.jsonl"
 RELATIONSHIP_PROFILES = "profiles/relationship_profiles.jsonl"
 
+ACTION_TYPES = ['none', 'action', 'non-verbal communication', 'speak', 'leave']
 
 @cache
 def get_sotopia_profiles(env_file=ENVIRONMENT_PROFILES, agent_file=AGENT_PROFILES, relationship_file=RELATIONSHIP_PROFILES):
@@ -68,35 +63,6 @@ def get_sotopia_profiles(env_file=ENVIRONMENT_PROFILES, agent_file=AGENT_PROFILE
 
     return environments, environment_dict, agent_dict, relationship_dict
 
-@cache
-def prepare_model(model_name):
-    compute_type = torch.float16
-
-    if 'cmu-lti/sotopia-pi-mistral-7b-BC_SR' in model_name:
-        model = AutoModelForCausalLM.from_pretrained(
-            "mistralai/Mistral-7B-Instruct-v0.1",
-            cache_dir="./.cache",
-            device_map='cuda',
-            quantization_config=BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_use_double_quant=True,
-                bnb_4bit_quant_type="nf4",
-                bnb_4bit_compute_dtype=compute_type,
-            )
-        )
-        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
-        model = PeftModel.from_pretrained(model, model_name).to("cuda")
-    elif 'mistralai/Mistral-7B-Instruct-v0.1' in model_name:
-        model = AutoModelForCausalLM.from_pretrained(
-            "mistralai/Mistral-7B-Instruct-v0.1",
-            cache_dir="./.cache",
-            device_map='cuda',
-        )
-        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
-    else:
-        raise RuntimeError(f"Model {model_name} not supported")
-    return model, tokenizer
-
 
 def introduction():
     with gr.Column(scale=2):
@@ -162,7 +128,7 @@ def sotopia_info_accordion(accordion_visible=True):
     with gr.Accordion("Sotopia Information", open=accordion_visible):
         with gr.Column():
             model_name_dropdown = gr.Dropdown(
-                choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "mistralai/Mistral-7B-Instruct-v0.1", "GPT3.5"],
+                choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "mistralai/Mistral-7B-Instruct-v0.1", "gpt-3.5-turbo"],
                 value="cmu-lti/sotopia-pi-mistral-7b-BC_SR",
                 interactive=True,
                 label="Model Selection"
@@ -213,50 +179,28 @@ def instructions_accordion(instructions, according_visible=False):
 
 def chat_tab():
     # history is a list of input-output pairs
+    _, environment_dict, agent_dict, _ = get_sotopia_profiles()
     def run_chat(
         message,
         history,
-        instructions,
+        environment_selection,
         user_agent_dropdown,
         bot_agent_dropdown,
        model_selection: str
     ):
-        user_name, bot_name = user_agent_dropdown.value.name, bot_agent_dropdown.value.name
-        model, tokenizer = prepare_model(model_selection)
-        prompt = format_sotopia_prompt(
-            message, history, instructions, user_name, bot_name
-        )
-        input_tokens = tokenizer(
-            prompt, return_tensors="pt", padding="do_not_pad"
-        ).input_ids.to("cuda")
-        input_length = input_tokens.shape[-1]
-        output_tokens = model.generate(
-            input_tokens,
-            temperature=TEMPERATURE,
-            top_p=TOP_P,
-            max_length=MAX_TOKENS,
-            pad_token_id=tokenizer.eos_token_id,
-            num_return_sequences=1,
-        )
-        output_tokens = output_tokens[:, input_length:]
-        text_output = tokenizer.decode(
-            output_tokens[0], skip_special_tokens=True
-        )
-        output = ""
-        for _ in range(5):
-            try:
-                output = format_bot_message(text_output)
-                break
-            except Exception as e:
-                print(e)
-                print("Retrying...")
-        return output
+        environment = environment_dict[environment_selection]
+        user_agent = agent_dict[user_agent_dropdown]
+        bot_agent = agent_dict[bot_agent_dropdown]
+
+        context = get_context_prompt(bot_agent, user_agent, environment)
+        dialogue_history, next_turn_idx = dialogue_history_prompt(message, history, user_agent, bot_agent)
+        prompt_history = f"{context}\n\n{dialogue_history}"
+        agent_action = generate_action(model_selection, prompt_history, next_turn_idx, ACTION_TYPES, bot_agent.name, TEMPERATURE)
+        return agent_action.to_natural_language()
 
-    _, environment_dict, agent_dict, _ = get_sotopia_profiles()
     with gr.Column():
         with gr.Row():
             model_name_dropdown, scenario_dropdown, user_agent_dropdown, bot_agent_dropdown = sotopia_info_accordion()
-    starter_prompt = gr.Textbox(value=get_starter_prompt(agent_dict[user_agent_dropdown.value], agent_dict[bot_agent_dropdown.value], environment_dict[scenario_dropdown.value]), label="Modify the prompt as needed:", visible=False)
 
     with gr.Column():
         with gr.Blocks():
@@ -279,7 +223,7 @@ def chat_tab():
             rtl=False,
         ),
         additional_inputs=[
-            starter_prompt,
+            scenario_dropdown,
            user_agent_dropdown,
            bot_agent_dropdown,
            model_name_dropdown,
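
The rewired run_chat can be exercised outside Gradio. Below is a rough sketch, not part of the commit: the SimpleNamespace stand-ins and their placeholder fields only mimic the Agent and Environment profile objects that get_sotopia_profiles() loads from profiles/*.jsonl, and the gpt-3.5-turbo path assumes OPENAI_API_KEY is set (app.py loads it from openai_api.key).

from types import SimpleNamespace

from utils import get_context_prompt, dialogue_history_prompt
from sotopia_pi_generate import generate_action

# Stand-in profiles; the real objects come from get_sotopia_profiles()
environment = SimpleNamespace(
    scenario="Two old friends catching up",
    agent_goals=["Unknown", "Convince your friend to join a hiking trip"],
)
user_agent = SimpleNamespace(name="Ava", background="...", personality="...", secret="...")
bot_agent = SimpleNamespace(name="Ben", background="...", personality="...", secret="...")

# Same assembly as run_chat: context block, then the numbered dialogue history
context = get_context_prompt(bot_agent, user_agent, environment)
dialogue_history, next_turn_idx = dialogue_history_prompt("Hi Ben!", [], user_agent, bot_agent)
prompt_history = f"{context}\n\n{dialogue_history}"

action = generate_action(
    "gpt-3.5-turbo",  # or "cmu-lti/sotopia-pi-mistral-7b-BC_SR" on a CUDA machine
    prompt_history,
    next_turn_idx,
    ['none', 'action', 'non-verbal communication', 'speak', 'leave'],
    bot_agent.name,
    0.7,
)
print(action.to_natural_language())  # what the chatbot UI would display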
sotopia_pi_generate.py ADDED
@@ -0,0 +1,219 @@
+import re
+
+import torch
+from peft import PeftModel
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    pipeline,
+)
+
+from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+from langchain_community.chat_models import ChatLiteLLM
+
+from langchain.chains import LLMChain
+from langchain.output_parsers import PydanticOutputParser
+from langchain.prompts import (
+    ChatPromptTemplate,
+    HumanMessagePromptTemplate,
+    PromptTemplate,
+)
+from langchain.schema import BaseOutputParser, OutputParserException
+from typing import TypeVar
+
+from sotopia.messages import ActionType, AgentAction
+from sotopia.utils import format_docstring
+from functools import cache
+import logging
+
+OutputType = TypeVar("OutputType", bound=object)
+
+log = logging.getLogger("generate")
+# logging_handler = LoggingCallbackHandler("langchain")
+
+def generate_action(
+    model_name: str,
+    history: str,
+    turn_number: int,
+    action_types: list[ActionType],
+    agent: str,
+    temperature: float = 0.7,
+) -> AgentAction:
+    """
+    Using langchain to generate the agent's next action
+    """
+    try:
+        # Normal case, model as agent
+        template = """
+            Imagine you are {agent}, your task is to act/speak as {agent} would, keeping in mind {agent}'s social goal.
+            You can find {agent}'s goal (or background) in the 'Here is the context of the interaction' field.
+            Note that {agent}'s goal is only visible to you.
+            You should try your best to achieve {agent}'s goal in a way that aligns with their character traits.
+            Additionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people have already said before).
+            {history}.
+            You are at Turn #{turn_number}. Your available action types are
+            {action_list}.
+            Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.
+
+            Please only generate a JSON string including the action type and the argument.
+            Your action should follow the given format:
+            {format_instructions}
+        """
+        return generate(
+            model_name=model_name,
+            template=template,
+            input_values=dict(
+                agent=agent,
+                turn_number=str(turn_number),
+                history=history,
+                action_list=" ".join(action_types),
+            ),
+            output_parser=PydanticOutputParser(pydantic_object=AgentAction),
+            temperature=temperature,
+        )
+    except Exception:
+        return AgentAction(action_type="none", argument="")
+
+@cache
+def prepare_model(model_name):
+    compute_type = torch.float16
+
+    if 'cmu-lti/sotopia-pi-mistral-7b-BC_SR' in model_name:
+        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token="REDACTED")
+        model = AutoModelForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-Instruct-v0.1",
+            cache_dir="./.cache",
+            device_map='cuda',
+            quantization_config=BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=compute_type,
+            ),
+            token="REDACTED"
+        )
+        model = PeftModel.from_pretrained(model, model_name).to("cuda")
+    else:
+        raise RuntimeError(f"Model {model_name} not supported")
+    return model, tokenizer
+
+def obtain_chain_hf(
+    model_name: str,
+    template: str,
+    input_variables: list[str],
+    temperature: float = 0.7,
+    max_retries: int = 6,
+    max_tokens: int = 2700,
+) -> LLMChain:
+    human_message_prompt = HumanMessagePromptTemplate(
+        prompt=PromptTemplate(template=template, input_variables=input_variables)
+    )
+    chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
+    model, tokenizer = prepare_model(model_name)
+    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=max_tokens, temperature=temperature)
+    hf = HuggingFacePipeline(pipeline=pipe)
+    chain = LLMChain(llm=hf, prompt=chat_prompt_template)
+    return chain
+
+def generate(
+    model_name: str,
+    template: str,
+    input_values: dict[str, str],
+    output_parser: BaseOutputParser[OutputType],
+    temperature: float = 0.7,
+) -> OutputType:
+    input_variables = re.findall(r"{(.*?)}", template)
+    assert (
+        set(input_variables) == set(list(input_values.keys()) + ["format_instructions"])
+        or set(input_variables) == set(list(input_values.keys()))
+    ), f"The variables in the template must match input_values except for format_instructions. Got {sorted(input_values.keys())}, expect {sorted(input_variables)}"
+    # process template
+    template = format_docstring(template)
+    chain = obtain_chain(model_name, template, input_variables, temperature)
+    if "format_instructions" not in input_values:
+        input_values["format_instructions"] = output_parser.get_format_instructions()
+    result = chain.predict([], **input_values)
+    try:
+        parsed_result = output_parser.parse(result)
+    except KeyboardInterrupt:
+        raise KeyboardInterrupt
+    except Exception as e:
+        log.debug(
+            f"[red] Failed to parse result: {result}\nEncountered exception {e}\nstarting to reparse",
+            extra={"markup": True},
+        )
+        reformat_parsed_result = format_bad_output(
+            result, format_instructions=output_parser.get_format_instructions()
+        )
+        parsed_result = output_parser.parse(reformat_parsed_result)
+    log.info(f"Generated result: {parsed_result}")
+    return parsed_result
+
+def format_bad_output(
+    ill_formed_output: str,
+    format_instructions: str,
+    model_name: str = "gpt-3.5-turbo",
+) -> str:
+    template = """
+        Given the string that cannot be parsed by a JSON parser, reformat it to a string that can be parsed by a JSON parser.
+        Original string: {ill_formed_output}
+
+        Format instructions: {format_instructions}
+
+        Please only generate the JSON:
+    """
+    chain = obtain_chain(
+        model_name=model_name,
+        template=template,
+        input_variables=re.findall(r"{(.*?)}", template),
+    )
+    input_values = {
+        "ill_formed_output": ill_formed_output,
+        "format_instructions": format_instructions,
+    }
+    reformat = chain.predict([], **input_values)
+    log.info(f"Reformatted output: {reformat}")
+    return reformat
+
+def obtain_chain(
+    model_name: str,
+    template: str,
+    input_variables: list[str],
+    temperature: float = 0.7,
+    max_retries: int = 6,
+) -> LLMChain:
+    """
+    Using langchain to build an LLMChain for the given model
+    """
+    if model_name in ["cmu-lti/sotopia-pi-mistral-7b-BC_SR"]:
+        return obtain_chain_hf(
+            model_name=model_name,
+            template=template,
+            input_variables=input_variables,
+            temperature=temperature,
+            max_retries=max_retries,
+        )
+
+    model_name = _return_fixed_model_version(model_name)
+    chat = ChatLiteLLM(
+        model=model_name,
+        temperature=temperature,
+        max_tokens=2700,  # tweak as needed
+        max_retries=max_retries,
+    )
+    human_message_prompt = HumanMessagePromptTemplate(
+        prompt=PromptTemplate(template=template, input_variables=input_variables)
+    )
+    chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
+    chain = LLMChain(llm=chat, prompt=chat_prompt_template)
+    return chain
+
+def _return_fixed_model_version(model_name: str) -> str:
+    return {
+        "gpt-3.5-turbo": "gpt-3.5-turbo-0613",
+        "gpt-3.5-turbo-finetuned": "ft:gpt-3.5-turbo-0613:academicscmu::8nY2zgdt",
+        "gpt-3.5-turbo-ft-MF": "ft:gpt-3.5-turbo-0613:academicscmu::8nuER4bO",
+        "gpt-4": "gpt-4-0613",
+        "gpt-4-turbo": "gpt-4-1106-preview",
+    }[model_name]
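
In short, obtain_chain routes the local sotopia-pi checkpoint to a 4-bit HuggingFacePipeline chain and every other model name through ChatLiteLLM (after pinning the version via _return_fixed_model_version), while generate parses the raw completion with a PydanticOutputParser and falls back to format_bad_output on a parse failure. A small sketch of that parsing contract (the raw_output string is invented for illustration):

from langchain.output_parsers import PydanticOutputParser
from sotopia.messages import AgentAction

parser = PydanticOutputParser(pydantic_object=AgentAction)
# This is what gets injected into the prompt as {format_instructions}
print(parser.get_format_instructions())

raw_output = '{"action_type": "speak", "argument": "Hi Ava, long time no see!"}'
action = parser.parse(raw_output)  # pydantic-validated AgentAction
print(action.to_natural_language())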
utils.py CHANGED
@@ -44,7 +44,10 @@ def get_format_guide():
     """
 
 def get_starter_prompt(machine_agent, human_agent, environment):
-    return f"Prompt after formatting:\nImagine you are {machine_agent.name}, your task is to act/speak as {machine_agent.name} would, keeping in mind {machine_agent.name}'s social goal.\nYou can find {machine_agent.name}'s background and goal in the 'Here is the context of the interaction' field.\nNote that {machine_agent.name}'s secret and goal is only visible to you.\nYou should try your best to achieve {machine_agent.name}'s goal in a way that aligns with their character traits.\nAdditionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people have already said before).\n\nHere is the context of this interaction:\n Scenario: {environment.scenario}\nParticipants: {human_agent.name} and {machine_agent.name}\n{human_agent.name}'s background: {human_agent.background} Personality and values description: {human_agent.personality} \n{machine_agent.name}'s background: {machine_agent.background} Personality and values description: {machine_agent.personality} {machine_agent.name}'s secrets: {machine_agent.secret}\n{human_agent.name}'s goal: Unknown\n{machine_agent.name}'s goal: {environment.agent_goals[1]}\nConversation Starts:"
+    return f"Imagine you are {machine_agent.name}, your task is to act/speak as {machine_agent.name} would, keeping in mind {machine_agent.name}'s social goal.\nYou can find {machine_agent.name}'s background and goal in the 'Here is the context of the interaction' field.\nNote that {machine_agent.name}'s secret and goal is only visible to you.\nYou should try your best to achieve {machine_agent.name}'s goal in a way that aligns with their character traits.\nAdditionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people have already said before).\n\nHere is the context of this interaction:\n Scenario: {environment.scenario}\nParticipants: {human_agent.name} and {machine_agent.name}\n{human_agent.name}'s background: {human_agent.background} Personality and values description: {human_agent.personality} \n{machine_agent.name}'s background: {machine_agent.background} Personality and values description: {machine_agent.personality} {machine_agent.name}'s secrets: {machine_agent.secret}\n{human_agent.name}'s goal: Unknown\n{machine_agent.name}'s goal: {environment.agent_goals[1]}\nConversation Starts:"
+
+def get_context_prompt(machine_agent, human_agent, environment):
+    return f"Here is the context of this interaction:\n Scenario: {environment.scenario}\nParticipants: {human_agent.name} and {machine_agent.name}\n{human_agent.name}'s background: {human_agent.background} Personality and values description: {human_agent.personality} \n{machine_agent.name}'s background: {machine_agent.background} Personality and values description: {machine_agent.personality} {machine_agent.name}'s secrets: {machine_agent.secret}\n{human_agent.name}'s goal: Unknown\n{machine_agent.name}'s goal: {environment.agent_goals[1]}\nConversation Starts:"
 
 
 # we define history as
@@ -102,6 +105,20 @@ def dialogue_history_creation(history, user_name, bot_name):
     last_turn_idx = len(history) * 2
     return dialogue_history, last_turn_idx
 
+def dialogue_history_prompt(message, history, user_agent, bot_agent):
+    dialogue_history = ""
+    for idx, turn in enumerate(history):
+        user_message, bot_message = turn
+        # TODO (haofeiyu): we first assume that the human talks first
+        user_turn_idx = idx * 2
+        bot_turn_idx = idx * 2 + 1
+        if not bot_message.startswith("["):  # if action type == speak, add 'said: ' to stay consistent with the dialogue prompt
+            bot_message = "said: " + bot_message
+        dialogue_history = f"{dialogue_history}\n\nTurn #{user_turn_idx}: {user_agent.name}: {user_message}\n\nTurn #{bot_turn_idx}: {bot_agent.name}: {bot_message}"
+    last_turn_idx = len(history) * 2
+    dialogue_history = f"{dialogue_history}\n\nTurn #{last_turn_idx+1}: {user_agent.name}: {message}\n."
+    return dialogue_history, last_turn_idx + 2
+
 
 def dialogue_history_truncation(dialogue_history, max_token_num, tokenizer):
     surpass_num = dialogue_history_length_check(
@@ -114,15 +131,12 @@ def dialogue_history_truncation(dialogue_history, max_token_num, tokenizer):
     return dialogue_history
 
 
-def format_sotopia_prompt(
+def format_history_prompt(
     message: str,
     history: List[Tuple[str, str]],
     instructions: str,
     user_name: str,
     bot_name: str,
-    include_all_chat_history: bool = True,
-    index: int = 1,
-    use_format_guide: bool = True,
 ) -> str:
     prompt = instructions.strip()
     dialogue_history, last_turn_idx = dialogue_history_creation(
@@ -130,4 +144,4 @@ def format_sotopia_prompt(
     )
     prompt = f"{prompt}\n{dialogue_history}"
     prompt = f"{prompt}\n\nTurn #{last_turn_idx+1}: {user_name}: {message}\n.\nYou are at Turn #{last_turn_idx+2}."
-    return prompt + get_format_guide() if use_format_guide else prompt
+    return prompt
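
The turn indexing in dialogue_history_prompt deserves a note: replayed (user, bot) exchanges are numbered 0, 1, 2, ... with the user on even turns and a "said: " prefix added to bot turns that are not bracketed actions; the incoming message is appended as Turn #(2 * len(history) + 1), and the returned next-turn index is 2 * len(history) + 2. A quick illustration with stand-in agents (SimpleNamespace just supplies .name here):

from types import SimpleNamespace
from utils import dialogue_history_prompt

ava, ben = SimpleNamespace(name="Ava"), SimpleNamespace(name="Ben")
history = [("Hi Ben!", "Hey Ava, good to see you.")]  # one completed exchange

prompt, next_turn = dialogue_history_prompt("Free this weekend?", history, ava, ben)
print(prompt)
# (leading blank lines omitted)
# Turn #0: Ava: Hi Ben!
#
# Turn #1: Ben: said: Hey Ava, good to see you.
#
# Turn #3: Ava: Free this weekend?
# .
print(next_turn)  # 4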