Wonderplex committed
Commit fe95067 · 1 parent: c3a4051

Feature/hf pipeline format (#47)


* removed older version utils

* changed requirements

* added changes for inference

Files changed (7)
  1. .gitignore +1 -0
  2. app.py +2 -2
  3. langchain_callback_handler.py +60 -0
  4. message_classes.py +343 -0
  5. requirements.txt +148 -6
  6. sotopia_pi_generate.py +57 -50
  7. utils.py +7 -99
.gitignore CHANGED
@@ -1,4 +1,5 @@
 __pycache__/
 .cache/
 openai_api.key
+hf_token.key
 core
app.py CHANGED
@@ -12,7 +12,7 @@ with open("openai_api.key", "r") as f:
     os.environ["OPENAI_API_KEY"] = f.read().strip()
 
 DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
-DEFAULT_MODEL_SELECTION = "gpt-3.5-turbo" # "mistralai/Mistral-7B-Instruct-v0.1"
+DEFAULT_MODEL_SELECTION = "gpt-3.5-turbo"
 TEMPERATURE = 0.7
 TOP_P = 1
 MAX_TOKENS = 1024
@@ -147,7 +147,7 @@ def sotopia_info_accordion(accordion_visible=True):
         interactive=True,
     )
     model_name_dropdown = gr.Dropdown(
-        choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "mistralai/Mistral-7B-Instruct-v0.1", "gpt-3.5-turbo", "gpt-4-turbo"],
+        choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit", "mistralai/Mistral-7B-Instruct-v0.1", "gpt-3.5-turbo", "gpt-4-turbo"],
         value=DEFAULT_MODEL_SELECTION,
         interactive=True,
         label="Model Selection"
langchain_callback_handler.py ADDED
@@ -0,0 +1,60 @@
+import logging
+from typing import Any
+
+from langchain.callbacks import StdOutCallbackHandler
+
+logging.addLevelName(15, "LangChain")
+
+
+class LoggingCallbackHandler(StdOutCallbackHandler):
+    """Callback Handler that prints to std out."""
+
+    always_verbose = True
+
+    def __init__(self, name: str) -> None:
+        """Initialize callback handler."""
+        super().__init__()
+        self.logger = logging.getLogger(name)
+        self.prompt = ""
+
+    def on_chain_start(self, *args: Any, **kwargs: Any) -> None:
+        pass
+
+    def on_chain_end(self, *args: Any, **kwargs: Any) -> None:
+        pass
+
+    def on_agent_action(self, *args: Any, **kwargs: Any) -> Any:
+        pass
+
+    def on_tool_end(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        pass
+
+    def on_tool_error(
+        self, error: BaseException | KeyboardInterrupt, **kwargs: Any
+    ) -> None:
+        """Do nothing."""
+        pass
+
+    def on_text(
+        self,
+        text: str,
+        color: str | None = None,
+        end: str = "",
+        **kwargs: Any,
+    ) -> None:
+        """Run when agent ends."""
+        # leave only prompt for environment
+        text = text.replace("\x1b[32;1m\x1b[1;3mHuman: ", "")
+        logging.log(15, f"LLM Call: {text}")
+        self.prompt = text
+
+    def retrive_prompt(self) -> str:
+        return self.prompt
+
+    def on_agent_finish(self, *args: Any, **kwargs: Any) -> None:
+        """Run on agent end."""
+        pass
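
For context, a minimal usage sketch (not part of the diff): sotopia_pi_generate.py below passes this handler as the callbacks argument of LLMChain.predict() and reads the captured prompt back. The template and ChatLiteLLM model here are placeholders, and a configured OpenAI key is assumed.

    from langchain.chains import LLMChain
    from langchain.prompts import PromptTemplate
    from langchain_community.chat_models import ChatLiteLLM

    from langchain_callback_handler import LoggingCallbackHandler

    handler = LoggingCallbackHandler("langchain")
    chain = LLMChain(
        llm=ChatLiteLLM(model="gpt-3.5-turbo"),  # placeholder LLM
        prompt=PromptTemplate.from_template("Say hello to {name}."),
    )
    # The first positional argument of predict() is the callbacks list.
    result = chain.predict([handler], name="Sotopia")
    prompt_sent = handler.retrive_prompt()  # method name as committed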
message_classes.py ADDED
@@ -0,0 +1,343 @@
+import re
+from typing import Literal, cast
+
+from pydantic import BaseModel, Field
+
+from utils import format_docstring
+
+ActionType = Literal["none", "speak", "non-verbal communication", "action", "leave"]
+
+
+class Message(BaseModel):
+    """
+    An interface for messages.
+    There is only one required method: to_natural_language
+    """
+
+    def to_natural_language(self) -> str:
+        raise NotImplementedError
+
+
+class SimpleMessage(Message):
+    """
+    A simple message with a single string field.
+    """
+
+    message: str = Field(description="the message")
+
+    def to_natural_language(self) -> str:
+        return self.message
+
+
+class Observation(Message):
+    last_turn: str = Field(description="the last turn of the conversation")
+    turn_number: int = Field(description="the turn number of the conversation")
+    available_actions: list[ActionType] = Field(description="the available actions")
+
+    def to_natural_language(self) -> str:
+        if self.turn_number == 0:
+            return f"\n{self.last_turn}\nConversation Starts:\n"
+        else:
+            return f"Turn #{self.turn_number-1}: {self.last_turn}\n"
+
+
+class ScriptBackground(Message):
+    scenario: str = Field(description="scenario of the episode")
+    p1_name: str = Field(description="name of participant 1")
+    p2_name: str = Field(description="name of participant 2")
+    p1_background: str = Field(description="background of participant 1")
+    p2_background: str = Field(description="background of participant 2")
+    p1_goal: str = Field(description="goal of participant 1")
+    p2_goal: str = Field(description="goal of participant 2")
+
+    def to_natural_language(self) -> str:
+        if self.p1_background or self.p2_background:
+            p1_background = self.p1_background if self.p1_background else "Unknown"
+            p2_background = self.p2_background if self.p2_background else "Unknown"
+            # Not using AND, since in stranger relation the background is not visible
+            return format_docstring(
+                f"""Here is the context of this interaction:
+                Scenario: {self.scenario}
+                Participants: {self.p1_name} and {self.p2_name}
+                {self.p1_name}'s background: {p1_background}
+                {self.p2_name}'s background: {p2_background}
+                {self.p1_name}'s goal: {self.p1_goal}
+                {self.p2_name}'s goal: {self.p2_goal}
+                """
+            )
+        else:
+            return format_docstring(
+                f"""Here is the context of this interaction:
+                Scenario: {self.scenario}
+                Participants: {self.p1_name} and {self.p2_name}
+                {self.p1_name}'s goal: {self.p1_goal}
+                {self.p2_name}'s goal: {self.p2_goal}
+                """
+            )
+
+
+class ScriptEnvironmentResponse(Message):
+    terminated: bool = Field(
+        description="whether the conversation is terminated",
+        default_factory=lambda: False,
+    )
+    p1_rate: float | tuple[float, dict[str, float]] | None = Field(
+        description="rating of participant 1, on the scale of 1 to 10"
+    )
+    p2_rate: float | tuple[float, dict[str, float]] | None = Field(
+        description="rating of participant 2, on the scale of 1 to 10"
+    )
+    comments: str | None = Field(
+        description="All of the comments supporting the termination and rating"
+    )
+
+    def to_natural_language(self) -> str:
+        reason_to_stop = format_docstring(
+            f"""Environment response:
+            {"The conversation is terminated." if self.terminated else ""}
+            {"Rating of participant 1" + str(self.p1_rate) if self.p1_rate is not None else ""}
+            {"Rating of participant 2" + str(self.p2_rate) if self.p2_rate is not None else ""}
+            {self.comments if self.comments is not None else ""}
+            """
+        )
+        clean_text = ""
+        for line in reason_to_stop.split("\n"):
+            if line.strip():
+                clean_text += line + "\n"
+        return clean_text
+
+
+class AgentAction(Message):
+    action_type: ActionType = Field(
+        description="whether to speak at this turn or choose to not do anything"
+    )
+    argument: str = Field(
+        description="the utterance if choose to speak, the expression or gesture if choose non-verbal communication, or the physical action if choose action"
+    )
+
+    def to_natural_language(self) -> str:
+        match self.action_type:
+            case "none":
+                return "did nothing"
+            case "speak":
+                return f'said: "{self.argument}"'
+            case "non-verbal communication":
+                return f"[{self.action_type}] {self.argument}"
+            case "action":
+                return f"[{self.action_type}] {self.argument}"
+            case "leave":
+                return "left the conversation"
+
+
+ScriptInteractionReturnType = tuple[
+    list[list[tuple[str, str, Message]]], list[tuple[str, Message]]
+]
+
+
+class ScriptInteraction(Message):
+    interactions: str = Field(
+        description="""The interaction between the two participants in maximum 20 turns. Each turn is separated by a newline, and should only describe one agent. Following the structure:
+        Turn #x
+        [participant's name] [action] {argument for some actions}
+
+        You can use different types of actions, but only use one in each turn. You should move other information into argument part. Below shows a python code snippet of the format for each action type:
+        match self.action_type:
+            case "none":
+                return "did nothing"
+            case "speak":
+                return f'said: "{self.argument}"'
+            case "non-verbal communication":
+                return f"[{self.action_type}] {self.argument}"
+            case "action":
+                return f"[{self.action_type}] {self.argument}"
+            case "leave":
+                return "left the conversation"
+
+        For example, the following is acceptable:
+        Turn #x
+        Oliver Thompson said: "Hey Esmeralda, what's wrong? You seem upset."
+        Turn #x
+        Esmeralda Solis [action] moved closer
+        Turn #x
+        Oliver Thompson [non-verbal communication] smiled
+        Turn #x
+        Esmeralda Solis did nothing
+        Turn #x
+        Oliver Thompson left the conversation
+        Turn #x
+        Esmeralda Solis [action] leaned in and lowered her voice: "Sorry"
+
+        And the following is not acceptable:
+        Turn #1
+        Oliver Thompson [speak] said: "Hey Esmeralda, what's wrong? You seem upset."
+        Turn #1
+        Esmeralda Solis non-verbal communication moved closer
+        """
+    )
+
+    def to_natural_language(self) -> str:
+        return self.interactions
+
+    def parse(
+        self, agent_names: list[str], background: str
+    ) -> tuple[list[list[tuple[str, str, Message]]], list[tuple[str, Message]]]:
+        interaction = self.interactions
+        # print("Interaction: ", interaction)
+        lines = self.split_by_turn(interaction)
+
+        agent_results = []
+        results: list[list[tuple[str, str, Message]]] = [
+            [
+                (
+                    "Environment",
+                    name,
+                    Observation(
+                        last_turn=background,
+                        turn_number=0,
+                        available_actions=["none"],
+                    ),
+                )
+                for name in agent_names
+            ]
+        ]
+
+        for line_idx, line in enumerate(lines):
+            try:
+                res = self.parse_single_dialogue(line)
+                action: AgentAction = cast(AgentAction, res["action"])
+                argument: str = cast(str, res["argument"])
+                cast(int, res["turn"])
+                name: str = cast(str, res["name"])
+
+                parsed_action = AgentAction(action_type=action, argument=argument)
+                if name not in agent_names:
+                    print(
+                        f"The name of the agent, {name}, is not in the list of agent names, {agent_names}"
+                    )
+                    name = agent_names[
+                        line_idx % 2
+                    ]  # TODO Not sure what name to be set here
+            except Exception as e:
+                print(
+                    f"Error when parsing the dialogue: {line}",
+                    f"The error is: {e}",
+                )
+                raise e
+                parsed_action = AgentAction(action_type="none", argument="")
+                name = agent_names[line_idx % 2]  # TODO same question as above
+            inactive_agent_name = (
+                agent_names[0] if name == agent_names[1] else agent_names[1]
+            )
+            results.append(
+                [
+                    (
+                        "Environment",
+                        name,
+                        Observation(
+                            last_turn="environment is the agent",
+                            turn_number=line_idx + 1,
+                            available_actions=["none"],
+                        ),
+                    )
+                    for name in agent_names
+                ]
+                + [
+                    (name, "Environment", parsed_action),
+                    (
+                        inactive_agent_name,
+                        "Environment",
+                        AgentAction(action_type="none", argument="did nothing"),
+                    ),
+                ]
+            )
+
+            agent_results.append((name, parsed_action))
+        # print("Parsed agent results: ", agent_results)
+        return (results, agent_results)  # type: ignore
+
+    def parse_single_dialogue(
+        self, dialogue: str
+    ) -> dict[str, str | int | AgentAction | None]:
+        """Parse a single dialogue string and return a dictionary with turn, name, action, and argument."""
+
+        # Match the turn number and name. Assume all agent name starts with a capital letter and is followed by lowercase letters
+        match_turn_name = re.match(
+            r"Turn #?(\d+):?\s*\n((?:[A-Z]['a-z]* ?)+)", dialogue
+        )
+
+        if not match_turn_name:
+            raise ValueError(
+                f"The dialogue does not match the expected format: {dialogue}"
+            )
+            return None  # TODO Which should we use, return None or raise error?
+
+        turn, name = match_turn_name.groups()
+        action_content = dialogue[
+            len(match_turn_name.group(0)) :
+        ].strip()  # Extract the action content
+
+        # Check for different action types
+        if "did nothing" in action_content:
+            action, argument = "none", ""
+        elif match := re.match(r'said: "(.*?)"', action_content):
+            action, argument = "speak", match.group(1)
+            action, argument = action.strip(), argument.strip()
+        elif match := re.match(r'\[speak\] said: "(.*?)"', action_content):
+            action, argument = "speak", match.group(1)
+            action, argument = action.strip(), argument.strip()
+        elif match := re.match(
+            r"\[(non-verbal communication|action)\] (.*)", action_content
+        ):
+            action, argument = match.groups()
+        elif "left the conversation" in action_content:
+            # TODO Make it more elegant to handle the situation of `left the conversation.`
+            action, argument = "leave", ""
+        else:
+            action, argument = None, None
+
+        parsed_item = {
+            "turn": int(turn),
+            "name": name.strip(),
+            "action": action,
+            "argument": argument,
+        }
+        return parsed_item
+
+    def split_by_turn(self, input_string: str) -> list[str]:
+        """Split the input dialogue string by turn and return a list of dialogues."""
+        # Split using 'Turn #' as delimiter, but keep the delimiter in the results
+        dialogues = re.split(r"(?=Turn #?\d+)", input_string)
+        # Remove any empty strings and strip whitespace
+        dialogues = [dialogue.strip() for dialogue in dialogues if dialogue.strip()]
+        dialogues = [dialogue for dialogue in dialogues if dialogue.startswith("Turn")]
+        # Change from Turn #x to Turn (#)x (# is optional)
+        dialogues[-1] = "\n".join(
+            dialogues[-1].split("\n")[:2]
+        )  # Discard further input in the last turn
+
+        for dialogue in dialogues:
+            # TODO this is current workaround for the issue of multiple agents in one turn
+            if len(dialogue.split("\n")) >= 3:
+                raise ValueError("Only one agent can act per turn.")
+        return dialogues
+
+    @staticmethod
+    def default_value_for_return_type() -> ScriptInteractionReturnType:
+        results_1: list[list[tuple[str, str, Message]]] = [
+            [
+                (
+                    "Environment",
+                    name,
+                    Observation(
+                        last_turn="Environment is the agent",
+                        turn_number=0,
+                        available_actions=["none"],
+                    ),
+                )
+                for name in ["none", "none"]
+            ]
+        ]
+        results_2: list[tuple[str, Message]] = [
+            ("", AgentAction(action_type="none", argument=""))
+        ]
+        return (results_1, results_2)
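
For orientation, a short sketch (not part of the diff) exercising the two classes the inference path uses most:

    from message_classes import AgentAction, Observation

    action = AgentAction(action_type="speak", argument="Hi, how are you?")
    print(action.to_natural_language())  # said: "Hi, how are you?"

    obs = Observation(last_turn="Hi!", turn_number=1, available_actions=["speak", "leave"])
    print(obs.to_natural_language())  # Turn #0: Hi!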
requirements.txt CHANGED
@@ -1,6 +1,148 @@
-sotopia
-gradio
-transformers
-torch
-peft
-bitsandbytes
+absl-py==2.1.0
+accelerate==0.29.3
+aiofiles==23.2.1
+aiohttp==3.9.5
+aiosignal==1.3.1
+altair==5.3.0
+annotated-types==0.6.0
+anyio==3.7.1
+attrs==23.2.0
+beartype==0.14.1
+bitsandbytes==0.43.1
+certifi==2024.2.2
+cffi==1.16.0
+charset-normalizer==3.3.2
+click==8.1.7
+cloudpickle==3.0.0
+contourpy==1.2.1
+cryptography==42.0.5
+cycler==0.12.1
+dataclasses-json==0.6.4
+datasets==2.18.0
+dill==0.3.8
+distro==1.9.0
+Farama-Notifications==0.0.4
+ffmpy==0.3.2
+filelock==3.13.4
+fonttools==4.51.0
+frozenlist==1.4.1
+fsspec==2024.2.0
+gin-config==0.5.0
+gradio==4.27.0
+gradio_client==0.15.1
+greenlet==3.0.3
+gymnasium==0.29.1
+h11==0.14.0
+hiredis==2.3.2
+httpcore==1.0.5
+httpx==0.27.0
+huggingface-hub==0.22.2
+idna==3.7
+importlib_metadata==7.1.0
+importlib_resources==6.4.0
+Jinja2==3.1.3
+jsonpatch==1.33
+jsonpointer==2.4
+jsonschema==4.21.1
+jsonschema-specifications==2023.12.1
+kiwisolver==1.4.5
+langchain==0.1.16
+langchain-community==0.0.33
+langchain-core==0.1.44
+langchain-openai==0.0.5
+langchain-text-splitters==0.0.1
+langsmith==0.1.48
+litellm==1.35.12
+lxml==4.9.4
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+marshmallow==3.21.1
+matplotlib==3.8.4
+mdurl==0.1.2
+more-itertools==10.2.0
+mpmath==1.3.0
+multidict==6.0.5
+multiprocess==0.70.16
+mypy==1.9.0
+mypy-extensions==1.0.0
+names==0.3.0
+networkx==3.3
+numpy==1.26.4
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==8.9.2.26
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu12==2.19.3
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.1.105
+openai==1.22.0
+orjson==3.10.1
+packaging==23.2
+pandas==2.2.2
+pandas-stubs==2.2.1.240316
+peft==0.10.0
+pettingzoo==1.24.0
+pillow==10.3.0
+psutil==5.9.8
+pyarrow==15.0.2
+pyarrow-hotfix==0.6
+pycparser==2.22
+pydantic==2.7.0
+pydantic_core==2.18.1
+pydub==0.25.1
+Pygments==2.17.2
+pyparsing==3.1.2
+python-dateutil==2.9.0.post0
+python-dotenv==1.0.1
+python-multipart==0.0.9
+python-ulid==1.1.0
+pytz==2024.1
+PyYAML==6.0.1
+redis==5.0.3
+referencing==0.34.0
+regex==2024.4.16
+requests==2.31.0
+rich==13.7.1
+rpds-py==0.18.0
+ruff==0.3.7
+safetensors==0.4.3
+scipy==1.13.0
+semantic-version==2.10.0
+shellingham==1.5.4
+six==1.16.0
+sniffio==1.3.1
+SQLAlchemy==2.0.29
+sseclient-py==1.8.0
+starlette==0.27.0
+sympy==1.12
+tabulate==0.9.0
+tenacity==8.2.3
+tiktoken==0.5.2
+tokenizers==0.19.1
+tomlkit==0.12.0
+toolz==0.12.1
+torch==2.2.2
+tqdm==4.66.2
+transformers==4.40.0
+triton==2.2.0
+typer==0.12.3
+types-cffi==1.16.0.20240331
+types-pyOpenSSL==24.0.0.20240417
+types-pytz==2024.1.0.20240417
+types-redis==4.6.0.20240417
+types-setuptools==69.5.0.20240415
+types-tqdm==4.66.0.20240417
+typing-inspect==0.9.0
+typing_extensions==4.11.0
+tzdata==2024.1
+urllib3==2.2.1
+uvicorn==0.23.2
+websockets==11.0.3
+xxhash==3.4.1
+yarl==1.9.4
+zipp==3.18.1
@@ -1,17 +1,18 @@
 import re
+from typing import TypeVar
+from functools import cache
+import logging
 
 import torch
-from peft import PeftModel
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     BitsAndBytesConfig,
 )
-
+from peft import PeftModel
 from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
 from langchain_community.chat_models import ChatLiteLLM
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-
 from langchain.chains import LLMChain
 from langchain.output_parsers import PydanticOutputParser
 from langchain.prompts import (
@@ -20,17 +21,16 @@ from langchain.prompts import (
     PromptTemplate,
 )
 from langchain.schema import BaseOutputParser, OutputParserException
-from typing import TypeVar
+from message_classes import ActionType, AgentAction
+from utils import format_docstring
 
-from sotopia.messages import ActionType, AgentAction
-from sotopia.utils import format_docstring
-from functools import cache
-import logging
+from langchain_callback_handler import LoggingCallbackHandler
 
-OutputType = TypeVar("OutputType", bound=object)
+HF_TOKEN_KEY_FILE = "./hf_token.key"
 
+OutputType = TypeVar("OutputType", bound=object)
 log = logging.getLogger("generate")
-# logging_handler = LoggingCallbackHandler("langchain")
+logging_handler = LoggingCallbackHandler("langchain")
 
 def generate_action(
     model_name: str,
@@ -39,7 +39,7 @@ def generate_action(
     action_types: list[ActionType],
     agent: str,
     temperature: float = 0.7,
-) -> tuple[AgentAction, str]:
+) -> AgentAction:
     """
     Using langchain to generate an example episode
     """
@@ -73,14 +73,26 @@ def generate_action(
             temperature=temperature,
         )
     except Exception:
-        return AgentAction(action_type="none", argument=""), ""
+        return AgentAction(action_type="none", argument="")
 
 @cache
-def prepare_model(model_name):
+def prepare_model(model_name, hf_token_key_file=HF_TOKEN_KEY_FILE):
     compute_type = torch.float16
+    with open(hf_token_key_file, 'r') as f:
+        hf_token = f.read().strip()
 
-    if 'cmu-lti/sotopia-pi-mistral-7b-BC_SR' in model_name:
-        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token="REDACTED")
+    if model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR':
+        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token=hf_token)
+        model = AutoModelForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-Instruct-v0.1",
+            cache_dir="./.cache",
+            device_map='cuda',
+            token=hf_token
+        )
+        model = PeftModel.from_pretrained(model, model_name).to("cuda")
+
+    elif model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit':
+        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token=hf_token)
         model = AutoModelForCausalLM.from_pretrained(
             "mistralai/Mistral-7B-Instruct-v0.1",
             cache_dir="./.cache",
@@ -91,11 +103,22 @@ def prepare_model(model_name):
                 bnb_4bit_quant_type="nf4",
                 bnb_4bit_compute_dtype=compute_type,
             ),
-            token="REDACTED"
+            token=hf_token
         )
         model = PeftModel.from_pretrained(model, model_name).to("cuda")
+
+    elif model_name == 'mistralai/Mistral-7B-Instruct-v0.1':
+        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token=hf_token)
+        model = AutoModelForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-Instruct-v0.1",
+            cache_dir="./.cache",
+            device_map='cuda',
+            token=hf_token
+        )
+
     else:
         raise RuntimeError(f"Model {model_name} not supported")
+
     return model, tokenizer
 
 def obtain_chain_hf(
@@ -111,9 +134,17 @@ def obtain_chain_hf(
     )
     chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
     model, tokenizer = prepare_model(model_name)
-    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=max_tokens, temperature=temperature)
+    pipe = pipeline("text-generation",
+                    model=model,
+                    tokenizer=tokenizer,
+                    max_new_tokens=100,
+                    temperature=temperature,
+                    return_full_text=False,
+                    do_sample=True,
+                    num_beams=3,
+                    length_penalty=-1.0,
+    )
     hf = HuggingFacePipeline(pipeline=pipe)
-    # import pdb; pdb.set_trace()
     chain = LLMChain(llm=hf, prompt=chat_prompt_template)
     return chain
 
@@ -123,7 +154,7 @@ def generate(
     input_values: dict[str, str],
    output_parser: BaseOutputParser[OutputType],
     temperature: float = 0.7,
-) -> tuple[OutputType, str]:
+) -> OutputType:
     # import pdb; pdb.set_trace()
     input_variables = re.findall(r"{(.*?)}", template)
     assert (
@@ -135,8 +166,9 @@ def generate(
     chain = obtain_chain(model_name, template, input_variables, temperature)
     if "format_instructions" not in input_values:
         input_values["format_instructions"] = output_parser.get_format_instructions()
-    result = chain.predict([], **input_values)
-    # import pdb; pdb.set_trace()
+    result = chain.predict([logging_handler], **input_values)
+    prompt = logging_handler.retrive_prompt()
+    import pdb; pdb.set_trace()
     try:
         parsed_result = output_parser.parse(result)
     except KeyboardInterrupt:
@@ -146,6 +178,7 @@ def generate(
             f"[red] Failed to parse result: {result}\nEncounter Exception {e}\nstart to reparse",
             extra={"markup": True},
         )
+        import pdb; pdb.set_trace()
         reformat_parsed_result = format_bad_output(
             result, format_instructions=output_parser.get_format_instructions()
         )
@@ -175,7 +208,7 @@ def format_bad_output(
         "ill_formed_output": ill_formed_output,
         "format_instructions": format_instructions,
     }
-    reformat = chain.predict([], **input_values)
+    reformat = chain.predict([logging_handler], **input_values)
     log.info(f"Reformated output: {reformat}")
     return reformat
@@ -189,7 +222,7 @@ def obtain_chain(
     """
     Using langchain to sample profiles for participants
     """
-    if model_name in ["cmu-lti/sotopia-pi-mistral-7b-BC_SR"]:
+    if model_name in ["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit"]:
         return obtain_chain_hf(
             model_name=model_name,
             template=template,
@@ -212,32 +245,6 @@ def obtain_chain(
     chain = LLMChain(llm=chat, prompt=chat_prompt_template)
     return chain
 
-def format_bad_output(
-    ill_formed_output: str,
-    format_instructions: str,
-    model_name: str = "gpt-3.5-turbo",
-) -> str:
-    template = """
-    Given the string that can not be parsed by json parser, reformat it to a string that can be parsed by json parser.
-    Original string: {ill_formed_output}
-
-    Format instructions: {format_instructions}
-
-    Please only generate the JSON:
-    """
-    chain = obtain_chain(
-        model_name=model_name,
-        template=template,
-        input_variables=re.findall(r"{(.*?)}", template),
-    )
-    input_values = {
-        "ill_formed_output": ill_formed_output,
-        "format_instructions": format_instructions,
-    }
-    reformat = chain.predict([], **input_values)
-    log.info(f"Reformated output: {reformat}")
-    return reformat
-
 def _return_fixed_model_version(model_name: str) -> str:
     return {
         "gpt-3.5-turbo": "gpt-3.5-turbo-0613",
utils.py CHANGED
@@ -1,5 +1,6 @@
 from typing import List, Tuple
 import ast
+import re
 
 class Agent:
     def __init__(self, agent_profile):
@@ -31,80 +32,10 @@ class Environment:
         self.agent_goals = env_profile["agent_goals"]
         self.relationship = env_profile["relationship"]
 
-def get_format_guide():
-    return """ Your available action types are
-    "none action speak non-verbal communication leave".
-    Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.
-
-    Please only generate a JSON string including the action type and the argument.
-    Your action should follow the given format:
-    \nAs an example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}, \"required\": [\"foo\"]}
-    the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance of the schema. The object {\"properties\": {\"foo\": [\"bar\", \"baz\"]}} is not well-formatted.
-    \nHere is the output schema:\n```\n{\"description\": \"An interface for messages.\\nThere is only one required method: to_natural_language\", \"properties\": {\"action_type\": {\"title\": \"Action Type\", \"description\": \"whether to speak at this turn or choose to not do anything\", \"enum\": [\"none\", \"speak\", \"non-verbal communication\", \"action\", \"leave\"], \"type\": \"string\"}, \"argument\": {\"title\": \"Argument\", \"description\": \"the utterance if choose to speak, the expression or gesture if choose non-verbal communication, or the physical action if choose action\", \"type\": \"string\"}}, \"required\": [\"action_type\", \"argument\"]}\n```\u001b[0m
-    """
-
-def get_starter_prompt(machine_agent, human_agent, environment):
-    return f"Imagine you are {machine_agent.name}, your task is to act/speak as {machine_agent.name} would, keeping in mind {machine_agent.name}'s social goal.\nYou can find {machine_agent.name}'s background and goal in the 'Here is the context of the interaction' field.\nNote that {machine_agent.name}'s secret and goal is only visible to you.\nYou should try your best to achieve {machine_agent.name}'s goal in a way that align with their character traits.\nAdditionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).\n\nHere is the context of this interaction:\n Scenario: {environment.scenario}\nParticipants: {human_agent.name} and {machine_agent.name}\n{human_agent.name}'s background: {human_agent.background} Personality and values description: {human_agent.personality} \n{machine_agent.name}'s background: {machine_agent.background} Personality and values description: {machine_agent.personality} {machine_agent.name}'s secrets: {machine_agent.secret}\n{human_agent.name}'s goal: Unknown\n{machine_agent.name}'s goal: {environment.agent_goals[1]}\nConversation Starts:"
-
+
 def get_context_prompt(machine_agent, human_agent, environment):
     return f"Here is the context of this interaction:\n Scenario: {environment.scenario}\nParticipants: {human_agent.name} and {machine_agent.name}\n{human_agent.name}'s background: {human_agent.background} Personality and values description: {human_agent.personality} \n{machine_agent.name}'s background: {machine_agent.background} Personality and values description: {machine_agent.personality} {machine_agent.name}'s secrets: {machine_agent.secret}\n{human_agent.name}'s goal: Unknown\n{machine_agent.name}'s goal: {environment.agent_goals[1]}\nConversation Starts:"
-
-
-# we define history as
-# [(user_message, bot_message), (user_message, bot_message)]
-
-# we define dialogue history as
-# user_name: user_message\nbot_name: bot_message\nuser_name: user_message\nbot_name: bot_message\n
-
-
-def dialogue_history_length_check(string, max_token, tokenizer):
-    prompt_tokens = len(tokenizer(string)["input_ids"])
-    return max(prompt_tokens - max_token, 0)
-
-
-def truncate_dialogue_history_to_length(dia_his, surpass_num, tokenizer):
-    dia_sen = dia_his.split("\n")
-    remove_len = 0
-    i = 0
-    while remove_len < surpass_num:
-        remove_len += len(tokenizer(dia_sen[i])["input_ids"])
-        i += 1
-    trunc_dia = "\n".join(p for p in dia_sen[i:])
-    return trunc_dia
-
-
-def format_bot_message(bot_message) -> str:
-    # # import pdb; pdb.set_trace()
-    start_idx, end_idx = bot_message.index("{"), bot_message.index("}")
-    if end_idx == -1:
-        bot_message += "'}"
-        end_idx = len(bot_message)
-    json_response = ast.literal_eval(bot_message[start_idx:end_idx+1])
-    match json_response["action_type"]:
-        case "none":
-            return 'did nothing'
-        case "speak":
-            return json_response["argument"]
-        case "non-verbal communication":
-            return f'[{json_response["action_type"]}] {json_response["argument"]}'
-        case "action":
-            return f'[{json_response["action_type"]}] {json_response["argument"]}'
-        case "leave":
-            return 'left the conversation'
-
-def dialogue_history_creation(history, user_name, bot_name):
-    dialogue_history = ""
-    for idx, turn in enumerate(history):
-        user_message, bot_message = turn
-        # TODOTODO (haofeiyu): we first assume that human talks first
-        user_turn_idx = idx * 2
-        bot_turn_idx = idx * 2 + 1
-        if not bot_message.startswith("["):  # if action type == speak, need to add 'said: ' to be consistent with the dialog prompt
-            bot_message = "said :" + bot_message
-        dialogue_history = f"{dialogue_history}\n\nTurn #{user_turn_idx}: {user_name}: {user_message}\n\nTurn #{bot_turn_idx}: {bot_name}: {bot_message}"
-    last_turn_idx = len(history) * 2
-    return dialogue_history, last_turn_idx
-
+
 def dialogue_history_prompt(message, history, user_agent, bot_agent):
     dialogue_history = ""
     for idx, turn in enumerate(history):
@@ -117,31 +48,8 @@ def dialogue_history_prompt(message, history, user_agent, bot_agent):
         dialogue_history = f"{dialogue_history}\n\nTurn #{user_turn_idx}: {user_agent.name}: {user_message}\n\nTurn #{bot_turn_idx}: {bot_agent.name}: {bot_message}"
     last_turn_idx = len(history) * 2
     dialogue_history = f"{dialogue_history}\n\nTurn #{last_turn_idx+1}: {user_agent.name}: {message}\n."
-    return dialogue_history, last_turn_idx+2
-
-
-def dialogue_history_truncation(dialogue_history, max_token_num, tokenizer):
-    surpass_num = dialogue_history_length_check(
-        dialogue_history, max_token_num, tokenizer
-    )
-    if surpass_num > 0:
-        dialogue_history = truncate_dialogue_history_to_length(
-            dialogue_history, surpass_num, tokenizer
-        )
-    return dialogue_history
-
-
-def format_hostory_prompt(
-    message: str,
-    history: List[Tuple[str, str]],
-    instructions: str,
-    user_name: str,
-    bot_name: str,
-) -> str:
-    prompt = instructions.strip()
-    dialogue_history, last_turn_idx = dialogue_history_creation(
-        history, user_name, bot_name
-    )
-    prompt = f"{prompt}\n{dialogue_history}"
-    prompt = f"{prompt}\n\nTurn #{last_turn_idx+1}: {user_name}: {message}\n.\nYou are at Turn #{last_turn_idx+2}."
-    return prompt
+    return dialogue_history, last_turn_idx + 2
+
+
+def format_docstring(docstring: str) -> str:
+    """Format a docstring for use in a prompt template."""
+    return re.sub("\n +", "\n", docstring).strip()
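
For reference, a small sketch (not part of the diff) of what the relocated format_docstring helper does: it collapses the leading indentation that triple-quoted prompt templates pick up from source code.

    from utils import format_docstring

    indented = """Here is the context of this interaction:
        Scenario: a first meeting
        Participants: A and B"""
    print(format_docstring(indented))
    # Here is the context of this interaction:
    # Scenario: a first meeting
    # Participants: A and B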