Tristan Thrush committed on
Commit 0c08d16
2 Parent(s): bd15f33 8b9e466
Files changed (2)
  1. README.md +15 -6
  2. app.py +48 -65
README.md CHANGED
@@ -9,11 +9,19 @@ app_file: app.py
 pinned: false
 ---
 
-A basic example of an RLHF interface with a Gradio app.
+An RLHF interface for data collection with [Amazon Mechanical Turk](https://www.mturk.com) and Gradio.
 
-**Instructions for someone to use for their own project:**
+## Instructions for using this Space for your own project
 
-*Setting up the Space*
+### Install dependencies
+
+First, create a Python virtual environment and install the project's dependencies as follows:
+
+```bash
+python -m pip install -r requirements.txt
+```
+
+### Setting up the Space
 
 1. Clone this repo and deploy it on your own Hugging Face space.
 2. Add the following secrets to your space:
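The secrets themselves fall outside this hunk, but judging from the `os.getenv` calls in `app.py` below, the Space presumably expects `DATASET_REPO_URL`, `FORCE_PUSH`, and `HF_TOKEN`. A minimal sanity check that they are set, assuming that list is complete (it is inferred from the code, not from the README):

```python
import os

# Presumed secret names, inferred from the os.getenv calls in app.py.
REQUIRED_SECRETS = ["DATASET_REPO_URL", "FORCE_PUSH", "HF_TOKEN"]

missing = [name for name in REQUIRED_SECRETS if not os.getenv(name)]
if missing:
    raise RuntimeError(f"Missing secrets/env vars: {', '.join(missing)}")
```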
@@ -38,7 +46,7 @@ python app.py
 
 The app will then be available at a local address, such as http://127.0.0.1:7860
 
-*Running Data Collection*
+### Running data collection
 
 1. In the local repo that you pulled, create a copy of `config.py.example`,
 just called `config.py`. Now, put keys from your AWS account in `config.py`.
@@ -47,11 +55,12 @@ The app will then be available at a local address, such as http://127.0.0.1:7860
 create an MTurk requester account associated with your AWS account.
 2. Run `python collect.py` locally.
 
-*Profit*
+### Profit
 Now, you should be watching HITs come into your Hugging Face dataset
 automatically!
 
-*Tips and Tricks*
+### Tips and tricks
+
 - Use caution while doing local development of your space and
 simultaneously running it on MTurk. Consider setting `FORCE_PUSH` to "no" in
 your local `.env` file.
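The diff does not show the contents of `config.py.example`, so the key names below are hypothetical, but since `collect.py` posts HITs from your AWS account, it presumably constructs a boto3 MTurk client along these lines (the sandbox endpoint shown here is safe for testing; drop `endpoint_url` to post real HITs):

```python
import boto3

# Hypothetical names; copy the real ones from config.py.example.
MTURK_KEY = "your-aws-access-key-id"
MTURK_SECRET = "your-aws-secret-access-key"

mturk = boto3.client(
    "mturk",
    aws_access_key_id=MTURK_KEY,
    aws_secret_access_key=MTURK_SECRET,
    region_name="us-east-1",
    # Sandbox HITs cost nothing and are never shown to real workers.
    endpoint_url="https://mturk-requester-sandbox.us-east-1.amazonaws.com",
)
print(mturk.get_account_balance()["AvailableBalance"])  # the sandbox reports 10000.00
```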
 
app.py CHANGED
@@ -3,8 +3,11 @@
 import json
 import os
 import threading
+import time
 import uuid
+from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
+from typing import List
 from urllib.parse import parse_qs
 
 import gradio as gr
@@ -17,15 +20,27 @@ from langchain.prompts import load_prompt
 
 from utils import force_git_push
 
-# These variables are for storing the mturk HITs in a Hugging Face dataset.
+
+def generate_response(chatbot: ConversationChain, input: str) -> str:
+    """Generates a response for a `langchain` chatbot."""
+    return chatbot.predict(input=input)
+
+def generate_responses(chatbots: List[ConversationChain], inputs: List[str]) -> List[str]:
+    """Generates parallel responses for a list of `langchain` chatbots."""
+    results = []
+    with ThreadPoolExecutor(max_workers=100) as executor:
+        for result in executor.map(generate_response, chatbots, inputs):
+            results.append(result)
+    return results
+
+
+# These variables are for storing the MTurk HITs in a Hugging Face dataset.
 if Path(".env").is_file():
     load_dotenv(".env")
 DATASET_REPO_URL = os.getenv("DATASET_REPO_URL")
 FORCE_PUSH = os.getenv("FORCE_PUSH")
 HF_TOKEN = os.getenv("HF_TOKEN")
 PROMPT_TEMPLATES = Path("prompt_templates")
-# Set env variable for langchain to communicate with Hugging Face Hub
-os.environ["HUGGINGFACEHUB_API_TOKEN"] = HF_TOKEN
 
 DATA_FILENAME = "data.jsonl"
 DATA_FILE = os.path.join("data", DATA_FILENAME)
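The two helpers added above fan a single prompt out to every chatbot at once: `executor.map` pairs the i-th chatbot with the i-th input and yields results in input order, so latency is roughly that of the slowest model rather than the sum of all of them. A self-contained sketch of the same pattern, with an invented stub standing in for `ConversationChain`:

```python
import time
from concurrent.futures import ThreadPoolExecutor

class StubChatbot:
    """Invented stand-in for a langchain ConversationChain; illustration only."""
    def __init__(self, name: str, delay: float):
        self.name, self.delay = name, delay

    def predict(self, input: str) -> str:
        time.sleep(self.delay)  # simulate a slow Inference API call
        return f"{self.name} reply to {input!r}"

chatbots = [StubChatbot("flan", 0.3), StubChatbot("t0", 0.1), StubChatbot("gptj", 0.2)]
inputs = ["hello"] * len(chatbots)

start = time.time()
with ThreadPoolExecutor(max_workers=100) as executor:
    results = list(executor.map(lambda bot, txt: bot.predict(input=txt), chatbots, inputs))
print(results)                        # ordered like `chatbots`, not by completion time
print(f"{time.time() - start:.2f}s")  # roughly 0.3s (the slowest), not 0.6s (the sum)
```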
@@ -58,52 +73,24 @@ asynchronous_push(f_stop)
 # Now let's run the app!
 prompt = load_prompt(PROMPT_TEMPLATES / "openai_chatgpt.json")
 
-chatbot_1 = ConversationChain(
-    llm=HuggingFaceHub(
-        repo_id="google/flan-t5-xl",
-        model_kwargs={"temperature": 1}
-    ),
-    prompt=prompt,
-    verbose=False,
-    memory=ConversationBufferMemory(ai_prefix="Assistant"),
-)
-
-chatbot_2 = ConversationChain(
-    llm=HuggingFaceHub(
-        repo_id="bigscience/bloom",
-        model_kwargs={"temperature": 0.7}
-    ),
-    prompt=prompt,
-    verbose=False,
-    memory=ConversationBufferMemory(ai_prefix="Assistant"),
-)
+# TODO: update this list with better, instruction-trained models
+MODEL_IDS = ["google/flan-t5-xl", "bigscience/T0_3B", "EleutherAI/gpt-j-6B"]
+chatbots = []
 
-chatbot_3 = ConversationChain(
+for model_id in MODEL_IDS:
+    chatbots.append(ConversationChain(
     llm=HuggingFaceHub(
-        repo_id="bigscience/T0_3B",
-        model_kwargs={"temperature": 1}
+        repo_id=model_id,
+        model_kwargs={"temperature": 1},
+        huggingfacehub_api_token=HF_TOKEN,
     ),
     prompt=prompt,
     verbose=False,
     memory=ConversationBufferMemory(ai_prefix="Assistant"),
-)
+    ))
 
-chatbot_4 = ConversationChain(
-    llm=HuggingFaceHub(
-        repo_id="EleutherAI/gpt-j-6B",
-        model_kwargs={"temperature": 1}
-    ),
-    prompt=prompt,
-    verbose=False,
-    memory=ConversationBufferMemory(ai_prefix="Assistant"),
-)
 
-model_id2model = {
-    "google/flan-t5-xl": chatbot_1,
-    "bigscience/bloom": chatbot_2,
-    "bigscience/T0_3B": chatbot_3,
-    "EleutherAI/gpt-j-6B": chatbot_4
-}
+model_id2model = {chatbot.llm.repo_id: chatbot for chatbot in chatbots}
 
 demo = gr.Blocks()
 
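One behavioral nuance of the consolidation above: the old code sampled `bigscience/bloom` at temperature 0.7, while the loop now applies `{"temperature": 1}` to every model (and bloom is dropped from the list). If per-model settings were wanted again, a lookup table would keep the loop intact; a sketch under that assumption, not part of the commit:

```python
# Hypothetical per-model settings; the commit itself uses a uniform temperature of 1,
# but any model could be given its own values here.
MODEL_KWARGS = {
    "google/flan-t5-xl": {"temperature": 1},
    "bigscience/T0_3B": {"temperature": 1},
    "EleutherAI/gpt-j-6B": {"temperature": 1},
}
# Then, inside the loop: model_kwargs=MODEL_KWARGS[model_id]
```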
@@ -117,11 +104,9 @@ with demo:
         "cnt": 0, "data": [],
         "past_user_inputs": [],
         "generated_responses": [],
-        "response_1": "",
-        "response_2": "",
-        "response_3": "",
-        "response_4": "",
     }
+    for idx in range(len(chatbots)):
+        state_dict[f"response_{idx+1}"] = ""
     state = gr.JSON(state_dict, visible=False)
 
     gr.Markdown("# RLHF Interface")
@@ -131,27 +116,29 @@ with demo:
 
     # Generate model prediction
     def _predict(txt, state):
-        # TODO: parallelize this!
-        response_1 = chatbot_1.predict(input=txt)
-        response_2 = chatbot_2.predict(input=txt)
-        response_3 = chatbot_3.predict(input=txt)
-        response_4 = chatbot_4.predict(input=txt)
+        start = time.time()
+        responses = generate_responses(chatbots, [txt] * len(chatbots))
+        print(f"Time taken to generate {len(chatbots)} responses: {time.time() - start:.2f} seconds")
 
         response2model_id = {}
-        response2model_id[response_1] = chatbot_1.llm.repo_id
-        response2model_id[response_2] = chatbot_2.llm.repo_id
-        response2model_id[response_3] = chatbot_3.llm.repo_id
-        response2model_id[response_4] = chatbot_4.llm.repo_id
+        for chatbot, response in zip(chatbots, responses):
+            response2model_id[response] = chatbot.llm.repo_id
 
         state["cnt"] += 1
 
         new_state_md = f"Inputs remaining in HIT: {state['cnt']}/{TOTAL_CNT}"
 
-        state["data"].append({"cnt": state["cnt"], "text": txt, "response_1": response_1, "response_2": response_2, "response_3": response_3, "response_4": response_4, "response2model_id": response2model_id})
+        metadata = {"cnt": state["cnt"], "text": txt}
+        for idx, response in enumerate(responses):
+            metadata[f"response_{idx + 1}"] = response
+
+        metadata["response2model_id"] = response2model_id
+
+        state["data"].append(metadata)
         state["past_user_inputs"].append(txt)
 
         past_conversation_string = "<br />".join(["<br />".join(["😃: " + user_input, "🤖: " + model_response]) for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"] + [""])])
-        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True, choices=[response_1, response_2, response_3, response_4], interactive=True, value=response_1), gr.update(value=past_conversation_string), state, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), new_state_md, dummy
+        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True, choices=responses, interactive=True, value=responses[0]), gr.update(value=past_conversation_string), state, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), new_state_md, dummy
 
     def _select_response(selected_response, state, dummy):
         done = state["cnt"] == TOTAL_CNT
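A detail of the new `_predict` worth seeing in isolation: `response2model_id` is keyed by the response text itself, so if two models happen to return an identical string, the later entry silently overwrites the earlier one. The toy values below are invented to show the collision:

```python
repo_ids = ["model/a", "model/b", "model/c"]
responses = ["Hello!", "Hi there.", "Hello!"]  # two models returned the same text

response2model_id = {}
for repo_id, response in zip(repo_ids, responses):
    response2model_id[response] = repo_id

print(response2model_id)  # {'Hello!': 'model/c', 'Hi there.': 'model/b'}; 'model/a' is gone
```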
@@ -182,17 +169,13 @@ with demo:
 
         if done:
             # Wipe the memory completely because we will be starting a new hit soon.
-            chatbot_1.memory = ConversationBufferMemory(ai_prefix="Assistant")
-            chatbot_2.memory = ConversationBufferMemory(ai_prefix="Assistant")
-            chatbot_3.memory = ConversationBufferMemory(ai_prefix="Assistant")
-            chatbot_4.memory = ConversationBufferMemory(ai_prefix="Assistant")
+            for chatbot in chatbots:
+                chatbot.memory = ConversationBufferMemory(ai_prefix="Assistant")
         else:
            # Sync all of the models' memories with the conversation path that
            # was actually taken.
-            chatbot_1.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
-            chatbot_2.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
-            chatbot_3.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
-            chatbot_4.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
+            for chatbot in chatbots:
+                chatbot.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
 
         text_input = gr.update(visible=False) if done else gr.update(visible=True)
         return gr.update(visible=False), gr.update(visible=True), text_input, gr.update(visible=False), state, gr.update(value=past_conversation_string), toggle_example_submit, toggle_final_submit, toggle_final_submit_preview,
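Note the semantics of the `else` branch above: every chatbot is pointed at the *same* memory object as the chain whose response the worker selected, so subsequent turns append to one shared buffer rather than to per-model copies. A plain-Python sketch of that aliasing, using invented toy classes:

```python
class Memory:
    def __init__(self):
        self.turns = []

class Bot:
    def __init__(self):
        self.memory = Memory()

bots = [Bot() for _ in range(3)]
selected = bots[1]                    # the bot whose response the worker picked
for bot in bots:
    bot.memory = selected.memory      # shared reference, not a copy

bots[0].memory.turns.append(("user", "hi"))
print(all(bot.memory.turns == [("user", "hi")] for bot in bots))  # True: one buffer
```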
@@ -207,7 +190,7 @@ with demo:
     with gr.Column(visible=False) as final_submit:
         submit_hit_button = gr.Button("Submit HIT")
     with gr.Column(visible=False) as final_submit_preview:
-        submit_hit_button_preview = gr.Button("Submit Work (preview mode; no mturk HIT credit, but your examples will still be stored)")
+        submit_hit_button_preview = gr.Button("Submit Work (preview mode; no MTurk HIT credit, but your examples will still be stored)")
 
     # Button event handlers
     get_window_location_search_js = """
 