Haofei Yu committed on
Commit
3ce130a
1 Parent(s): 6774d89

Feature/support multi turn (#14)

Browse files

* add the issue and pr template

* only show generated conversation

* support multi-turn sotopia prompt

Files changed (2) hide show
  1. app.py +24 -10
  2. utils.py +53 -12
app.py CHANGED
@@ -2,11 +2,12 @@ import gradio as gr
2
  from dataclasses import dataclass
3
  import os
4
  import torch
 
5
  from uuid import uuid4
6
  from peft import PeftModel, PeftConfig
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
8
 
9
- from utils import Agent, get_starter_prompt, format_chat_prompt
10
 
11
 
12
  HUMAN_AGENT = Agent(
@@ -23,18 +24,23 @@ MACHINE_AGENT = Agent(
23
  secrets="Descendant of a wealthy oil tycoon, rejects family fortune",
24
  personality="Benjamin Jackson, expressive and imaginative, leans towards self-direction and liberty. His decisions aim for societal betterment.",)
25
 
26
- DEFUALT_INSTRUCTIONS = get_starter_prompt(MACHINE_AGENT, HUMAN_AGENT, "Conversation between two friends, where one is upset and crying")
 
 
 
 
 
 
27
 
28
  DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
29
  MODEL_NAME = "cmu-lti/sotopia-pi-mistral-7b-BC_SR"
30
  COMPUTE_DTYPE = torch.float16
31
 
32
  config_dict = PeftConfig.from_json_file("peft_config.json")
33
- # import pdb; pdb.set_trace()
34
  config = PeftConfig.from_peft_type(**config_dict)
35
  tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
36
- model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
37
- model = PeftModel.from_pretrained(model, MODEL_NAME, config=config).to(COMPUTE_DTYPE).to("cuda")
38
  according_visible = True
39
 
40
 
@@ -109,10 +115,10 @@ def chat_accordion():
109
  max_lines=1,
110
  visible=False,
111
  )
112
-
113
  return temperature, instructions, user_name, bot_name, session_id, max_tokens
114
 
115
 
 
116
  def run_chat(
117
  message: str,
118
  history,
@@ -123,7 +129,13 @@ def run_chat(
123
  top_p: float,
124
  max_tokens: int
125
  ):
126
- prompt = format_chat_prompt(message, history, instructions, user_name, bot_name)
 
 
 
 
 
 
127
  input_tokens = tokenizer(prompt, return_tensors="pt", padding="do_not_pad").input_ids.to("cuda")
128
  input_length = input_tokens.shape[-1]
129
  output_tokens = model.generate(
@@ -138,7 +150,7 @@ def run_chat(
138
  text_output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
139
  return text_output
140
 
141
-
142
  def chat_tab():
143
  with gr.Column():
144
  with gr.Row():
@@ -160,7 +172,10 @@ def chat_tab():
160
  render=False,
161
  show_label=False,
162
  rtl=False,
163
- avatar_images=("images/user_icon.png", "images/bot_icon.png"),
 
 
 
164
  ),
165
  textbox=gr.Textbox(
166
  placeholder="Write your message here...",
@@ -184,7 +199,6 @@ def chat_tab():
184
  )
185
 
186
 
187
-
188
  def main():
189
  with gr.Blocks(
190
  css="""#chat_container {height: 820px; width: 1000px; margin-left: auto; margin-right: auto;}
 
2
  from dataclasses import dataclass
3
  import os
4
  import torch
5
+ import transformers
6
  from uuid import uuid4
7
  from peft import PeftModel, PeftConfig
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
9
 
10
+ from utils import Agent, get_starter_prompt, format_sotopia_prompt
11
 
12
 
13
  HUMAN_AGENT = Agent(
 
24
  secrets="Descendant of a wealthy oil tycoon, rejects family fortune",
25
  personality="Benjamin Jackson, expressive and imaginative, leans towards self-direction and liberty. His decisions aim for societal betterment.",)
26
 
27
+ SCENARIO = "Conversation between two friends, where one is upset and crying"
28
+
29
+ DEFUALT_INSTRUCTIONS = get_starter_prompt(
30
+ MACHINE_AGENT,
31
+ HUMAN_AGENT,
32
+ SCENARIO
33
+ )
34
 
35
  DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
36
  MODEL_NAME = "cmu-lti/sotopia-pi-mistral-7b-BC_SR"
37
  COMPUTE_DTYPE = torch.float16
38
 
39
  config_dict = PeftConfig.from_json_file("peft_config.json")
 
40
  config = PeftConfig.from_peft_type(**config_dict)
41
  tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
42
+ model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1").to("cuda")
43
+ model = PeftModel.from_pretrained(model, MODEL_NAME, config=config).to("cuda")
44
  according_visible = True
45
 
46
 
 
115
  max_lines=1,
116
  visible=False,
117
  )
 
118
  return temperature, instructions, user_name, bot_name, session_id, max_tokens
119
 
120
 
121
+ # history is a list of (user_message, bot_message) pairs
122
  def run_chat(
123
  message: str,
124
  history,
 
129
  top_p: float,
130
  max_tokens: int
131
  ):
132
+ prompt = format_sotopia_prompt(
133
+ message,
134
+ history,
135
+ instructions,
136
+ user_name,
137
+ bot_name
138
+ )
139
  input_tokens = tokenizer(prompt, return_tensors="pt", padding="do_not_pad").input_ids.to("cuda")
140
  input_length = input_tokens.shape[-1]
141
  output_tokens = model.generate(
 
150
  text_output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
151
  return text_output
152
 
153
+
154
  def chat_tab():
155
  with gr.Column():
156
  with gr.Row():
 
172
  render=False,
173
  show_label=False,
174
  rtl=False,
175
+ avatar_images=(
176
+ "images/user_icon.png",
177
+ "images/bot_icon.png"
178
+ ),
179
  ),
180
  textbox=gr.Textbox(
181
  placeholder="Write your message here...",
 
199
  )
200
 
201
 
 
202
  def main():
203
  with gr.Blocks(
204
  css="""#chat_container {height: 820px; width: 1000px; margin-left: auto; margin-right: auto;}
utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  class Agent:
2
  def __init__(self, name, background, goal, secrets, personality):
3
  self.name = name
@@ -9,23 +11,62 @@ class Agent:
9
  def get_starter_prompt(machine_agent, human_agent, scenario):
10
  return f"Prompt after formatting:\nImagine you are {machine_agent.name}, your task is to act/speak as {machine_agent.name} would, keeping in mind {machine_agent.name}'s social goal.\nYou can find {machine_agent.name}'s background and goal in the 'Here is the context of the interaction' field.\nNote that {machine_agent.name}'s secret and goal is only visible to you.\nYou should try your best to achieve {machine_agent.name}'s goal in a way that align with their character traits.\nAdditionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).\n\nHere is the context of this interaction:\n Scenario: {scenario}\nParticipants: {human_agent.name} and {machine_agent.name}\n{human_agent.name}'s background: {human_agent.background} Personality and values description: {human_agent.personality} \n{machine_agent.name}'s background: {machine_agent.background} Personality and values description: {machine_agent.personality} {machine_agent.name}'s secrets: {machine_agent.secrets}\n{human_agent.name}'s goal: Unknown\n{machine_agent.name}'s goal: {machine_agent.name}\nConversation Starts:"
11
 
12
- def format_chat_prompt(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  message: str,
14
- chat_history,
15
  instructions: str,
16
  user_name: str,
17
  bot_name: str,
18
  include_all_chat_history: bool = True,
19
  index : int = 1
20
  ) -> str:
21
- instructions = instructions.strip()
22
- prompt = instructions
23
- if not include_all_chat_history:
24
- if index >= 0:
25
- index = -index
26
- chat_history = chat_history[index:]
27
- for turn in chat_history:
28
- user_message, bot_message = turn
29
- prompt = f"{prompt}\n{user_name}: {user_message}\n{bot_name}: {bot_message}"
30
- prompt = f"{prompt}\n{user_name}: {message}\n{bot_name}:"
31
  return prompt
 
1
+ from typing import Tuple, List
2
+
3
  class Agent:
4
  def __init__(self, name, background, goal, secrets, personality):
5
  self.name = name
 
11
  def get_starter_prompt(machine_agent, human_agent, scenario):
12
  return f"Prompt after formatting:\nImagine you are {machine_agent.name}, your task is to act/speak as {machine_agent.name} would, keeping in mind {machine_agent.name}'s social goal.\nYou can find {machine_agent.name}'s background and goal in the 'Here is the context of the interaction' field.\nNote that {machine_agent.name}'s secret and goal is only visible to you.\nYou should try your best to achieve {machine_agent.name}'s goal in a way that align with their character traits.\nAdditionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).\n\nHere is the context of this interaction:\n Scenario: {scenario}\nParticipants: {human_agent.name} and {machine_agent.name}\n{human_agent.name}'s background: {human_agent.background} Personality and values description: {human_agent.personality} \n{machine_agent.name}'s background: {machine_agent.background} Personality and values description: {machine_agent.personality} {machine_agent.name}'s secrets: {machine_agent.secrets}\n{human_agent.name}'s goal: Unknown\n{machine_agent.name}'s goal: {machine_agent.name}\nConversation Starts:"
13
 
14
+ # we define history as
15
+ # [(user_message, bot_message), (user_message, bot_message)]
16
+
17
+ # we define dialogue history as
18
+ # user_name: user_message\nbot_name: bot_message\nuser_name: user_message\nbot_name: bot_message\n
19
+
20
def dialogue_history_length_check(string, max_token, tokenizer):
    """Return by how many tokens ``string`` exceeds ``max_token``.

    ``tokenizer`` is called on the text and must return a mapping with an
    ``"input_ids"`` sequence. The result is 0 when the text fits the budget.
    """
    token_ids = tokenizer(string)["input_ids"]
    overflow = len(token_ids) - max_token
    # Clamp at zero: a text under budget has no overflow.
    return overflow if overflow > 0 else 0
23
+
24
+
25
def truncate_dialogue_history_to_length(dia_his, surpass_num, tokenizer):
    """Drop the oldest lines of a dialogue until enough tokens are removed.

    The dialogue is split on newlines and lines are discarded from the front
    until the discarded lines account for at least ``surpass_num`` tokens,
    counting each line with a separate ``tokenizer`` call.

    Args:
        dia_his: newline-separated dialogue history.
        surpass_num: number of tokens that must be removed.
        tokenizer: callable returning a mapping with an ``"input_ids"``
            sequence for a given string.

    Returns:
        The remaining lines joined with newlines; an empty string if every
        line had to be dropped.
    """
    lines = dia_his.split("\n")
    removed_tokens = 0
    cut = 0
    # Stop at the end of the list: previously an overflow larger than the
    # summed per-line token counts indexed past the end and raised IndexError.
    while removed_tokens < surpass_num and cut < len(lines):
        removed_tokens += len(tokenizer(lines[cut])["input_ids"])
        cut += 1
    return "\n".join(lines[cut:])
34
+
35
+
36
def dialogue_history_creation(history, user_name, bot_name):
    """Render chat history as numbered Sotopia turns.

    Each ``(user_message, bot_message)`` pair becomes two turns: the user's
    message at an even turn index and the bot's reply at the following odd
    index, each preceded by a blank line.

    Args:
        history: sequence of ``(user_message, bot_message)`` pairs.
        user_name: speaker label for the user turns.
        bot_name: speaker label for the bot turns.

    Returns:
        A ``(dialogue_history, last_turn_idx)`` tuple, where
        ``last_turn_idx`` is ``len(history) * 2``, i.e. the index of the
        first turn not yet used.
    """
    # TODO(haofeiyu): this assumes the human always talks first in each pair.
    parts = []
    for idx, (user_message, bot_message) in enumerate(history):
        # Collect pieces and join once instead of re-formatting the growing
        # accumulator on every turn.
        parts.append(f"\n\nTurn #{idx * 2}: {user_name}: {user_message}")
        parts.append(f"\n\nTurn #{idx * 2 + 1}: {bot_name}: {bot_message}")
    return "".join(parts), len(history) * 2
46
+
47
+
48
def dialogue_history_truncation(dialogue_history, max_token_num, tokenizer):
    """Trim ``dialogue_history`` when it exceeds ``max_token_num`` tokens.

    The overflow is measured with ``dialogue_history_length_check``. When it
    is positive, the history is passed to
    ``truncate_dialogue_history_to_length`` with that overflow; otherwise the
    history is returned unchanged.
    """
    overflow = dialogue_history_length_check(
        dialogue_history, max_token_num, tokenizer
    )
    if overflow <= 0:
        # Already within budget; nothing to trim.
        return dialogue_history
    return truncate_dialogue_history_to_length(
        dialogue_history, overflow, tokenizer
    )
53
+
54
+
55
def format_sotopia_prompt(
    message: str,
    history: List[Tuple[str, str]],
    instructions: str,
    user_name: str,
    bot_name: str,
    include_all_chat_history: bool = True,
    index : int = 1
) -> str:
    """Assemble the multi-turn Sotopia prompt for the next model reply.

    The prompt is the stripped ``instructions``, followed by the numbered
    dialogue history, the user's new ``message`` as the next turn, and a
    closing line telling the model which turn it is now at.

    Args:
        message: the user's latest message.
        history: earlier ``(user_message, bot_message)`` pairs.
        instructions: the starter prompt describing scenario and agents.
        user_name: speaker label for user turns.
        bot_name: speaker label for bot turns.
        include_all_chat_history: currently ignored; kept so existing
            callers keep working.
        index: currently ignored; kept so existing callers keep working.

    Returns:
        The full prompt string to feed to the tokenizer.
    """
    prompt = instructions.strip()
    dialogue_history, last_turn_idx = dialogue_history_creation(
        history,
        user_name,
        bot_name
    )
    prompt = f"{prompt}\n{dialogue_history}"
    # History occupies turns 0 .. last_turn_idx - 1, so the new user message
    # is turn ``last_turn_idx`` and the model replies at the following turn.
    # The previous +1/+2 offsets skipped a turn number.
    user_turn_idx = last_turn_idx
    bot_turn_idx = last_turn_idx + 1
    prompt = f"{prompt}\n\nTurn #{user_turn_idx}: {user_name}: {message}\n.\nYou are at Turn #{bot_turn_idx}."
    return prompt