Ben Burtenshaw committed on
Commit
142be7a
·
1 Parent(s): 5776d7d
pages/2_👩🏼‍🔬 Describe Domain.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import streamlit as st
4
+
5
+ from hub import push_dataset_to_hub, pull_seed_data_from_repo
6
+ from infer import query
7
+ from defaults import (
8
+ N_PERSPECTIVES,
9
+ N_TOPICS,
10
+ SEED_DATA_PATH,
11
+ PIPELINE_PATH,
12
+ DATASET_REPO_ID,
13
+ )
14
+ from utils import project_sidebar, create_seed_terms, create_application_instruction
15
+
16
+
17
# Set the browser tab title/icon, then draw the shared project sidebar.
st.set_page_config(page_title="Domain Data Grower", page_icon="🧑‍🌾")
project_sidebar()


################################################################################
# HEADER
################################################################################

# Page title followed by the step heading and a one-line description.
st.header("🧑‍🌾 Domain Data Grower")
st.divider()
st.subheader(
    "Step 2. Define the specific domain that you want to generate synthetic data for."
)
st.write(
    "Define the project details, including the project name, domain, "
    "and API credentials"
)
36
+
37
+
38
+ ################################################################################
39
+ # LOAD EXISTING DOMAIN DATA
40
+ ################################################################################
41
+
42
# Resolve the dataset repo for this session. The values are read with .get()
# so that a session missing them produces a readable message instead of a
# KeyError traceback.
_hub_username = st.session_state.get("hub_username")
_project_name = st.session_state.get("project_name")
if not (_hub_username and _project_name):
    st.error("Please create a dataset repo on the Hub before describing the domain")
    st.stop()

# NOTE: this rebinds the DATASET_REPO_ID name imported from `defaults` to the
# repo of the current session.
DATASET_REPO_ID = f"{_hub_username}/{_project_name}"
SEED_DATA = pull_seed_data_from_repo(
    DATASET_REPO_ID, hub_token=st.session_state.get("hub_token")
)

# Defaults used to pre-fill the form. `or` (rather than a .get() default) also
# covers keys that are present but empty, so the list defaults always have at
# least one element and indexing them with [0] is safe.
DEFAULT_DOMAIN = SEED_DATA.get("domain", "")
DEFAULT_PERSPECTIVES = SEED_DATA.get("perspectives") or [""]
DEFAULT_TOPICS = SEED_DATA.get("topics") or [""]
DEFAULT_EXAMPLES = SEED_DATA.get("examples") or [{"question": "", "answer": ""}]
DEFAULT_SYSTEM_PROMPT = SEED_DATA.get("domain_expert_prompt", "")
53
+
54
+ ################################################################################
55
+ # Domain Expert Section
56
+ ################################################################################
57
+
58
# Labels of the tabs, in display order; st.tabs returns one container per label.
_tab_labels = [
    "👩🏼‍🔬 Domain Expert",
    "🔍 Domain Perspectives",
    "🕸️ Domain Topics",
    "📚 Examples",
    "🌱 Raw Seed Data",
]
(
    tab_domain_expert,
    tab_domain_perspectives,
    tab_domain_topics,
    tab_examples,
    tab_raw_seed,
) = st.tabs(tabs=_tab_labels)
73
+
74
with tab_domain_expert:
    # Short instruction plus an explanatory note on what a domain expert is.
    st.text("Define the domain expertise that you want to train a language model")
    st.info(
        "A domain expert is a person who is an expert in a particular field or area. "
        "For example, a domain expert in farming would be someone who has extensive "
        "knowledge and experience in farming and agriculture."
    )

    # Both inputs are pre-filled from the seed data defaults.
    domain = st.text_input(label="Domain Name", value=DEFAULT_DOMAIN)

    domain_expert_prompt = st.text_area(
        "Domain Expert Definition",
        value=DEFAULT_SYSTEM_PROMPT,
        height=200,
    )
87
+
88
+ ################################################################################
89
+ # Domain Perspectives
90
+ ################################################################################
91
+
92
with tab_domain_perspectives:
    st.text("Define the different perspectives from which the domain can be viewed")
    st.info(
        "\n"
        "Perspectives are different viewpoints or angles from which a domain can be viewed.\n"
        "For example, the domain of farming can be viewed from the perspective of a commercial\n"
        "farmer or an independent family farmer."
    )

    # Start from the perspectives stored in this session, or from the first
    # default one. The slice plus fallback tolerates an empty default list.
    perspectives = st.session_state.get(
        "perspectives",
        DEFAULT_PERSPECTIVES[:1] or [""],
    )
    perspectives_container = st.container()

    # One text input per perspective; the list is rebuilt from the widget values.
    perspectives = [
        perspectives_container.text_input(
            f"Domain Perspective {i + 1}", value=perspective
        )
        for i, perspective in enumerate(perspectives)
    ]

    if st.button("Add Perspective", key="add_perspective"):
        n = len(perspectives)
        # Pre-fill the new field with the next default perspective when one
        # exists. The bound includes len(DEFAULT_PERSPECTIVES) because the
        # defaults may hold fewer than N_PERSPECTIVES entries.
        has_default = n < min(N_PERSPECTIVES, len(DEFAULT_PERSPECTIVES))
        value = DEFAULT_PERSPECTIVES[n] if has_default else ""
        perspectives.append(
            perspectives_container.text_input(
                f"Domain Perspective {n + 1}", value=value
            )
        )

    # Persist the current values so the next rerun renders the same fields.
    st.session_state["perspectives"] = perspectives
122
+
123
+
124
+ ################################################################################
125
+ # Domain Topics
126
+ ################################################################################
127
+
128
with tab_domain_topics:
    st.text("Define the main themes or subjects that are relevant to the domain")
    st.info(
        "Topics are the main themes or subjects that are relevant to the domain. "
        "For example, the domain of farming can have topics like soil health, "
        "crop rotation, or livestock management."
    )

    # Start from the topics stored in this session, or from the first default
    # one. The slice plus fallback tolerates an empty default list.
    topics = st.session_state.get(
        "topics",
        DEFAULT_TOPICS[:1] or [""],
    )
    topics_container = st.container()

    # One text input per topic; the list is rebuilt from the widget values.
    topics = [
        topics_container.text_input(f"Domain Topic {i + 1}", value=topic)
        for i, topic in enumerate(topics)
    ]

    if st.button("Add Topic", key="add_topic"):
        n = len(topics)
        # Pre-fill the new field with the next default topic when one exists.
        # The bound includes len(DEFAULT_TOPICS) because the defaults may hold
        # fewer than N_TOPICS entries.
        has_default = n < min(N_TOPICS, len(DEFAULT_TOPICS))
        value = DEFAULT_TOPICS[n] if has_default else ""
        # Same label format as the inputs created in the list above, so the
        # field keeps one label across reruns.
        topics.append(
            topics_container.text_input(f"Domain Topic {n + 1}", value=value)
        )

    # Persist the current values so the next rerun renders the same fields.
    st.session_state["topics"] = topics
149
+
150
+
151
+ ################################################################################
152
+ # Examples Section
153
+ ################################################################################
154
+
155
with tab_examples:
    st.text(
        "Add high-quality questions and answers that can be used to generate synthetic data"
    )
    st.info(
        "\n"
        "Examples are high-quality questions and answers that can be used to generate\n"
        "synthetic data for the domain. These examples will be used to train the language model\n"
        "to generate questions and answers.\n"
    )

    # Examples edited earlier in this session, or a single blank example.
    examples = st.session_state.get(
        "examples",
        [{"question": "", "answer": ""}],
    )

    for n, example in enumerate(examples, 1):
        question = example["question"]
        answer = example["answer"]
        examples_container = st.container()
        question_column, answer_column = examples_container.columns(2)

        if st.button(f"Generate Answer {n}"):
            # .get() plus a truthiness test covers a missing key, None and an
            # empty string alike.
            token = st.session_state.get("hub_token")
            if not token:
                st.error("Please provide a Hub token to generate answers")
            else:
                # The answer from `query` becomes the value of the answer
                # text area rendered below.
                answer = query(question, token)

        with question_column:
            question = st.text_area(f"Question {n}", value=question)

        with answer_column:
            answer = st.text_area(f"Answer {n}", value=answer)

        # Write the (possibly edited) pair back and persist it for the next rerun.
        examples[n - 1] = {"question": question, "answer": answer}
        st.session_state["examples"] = examples
        st.divider()

    if st.button("Add Example"):
        # Append a blank pair and rerun so it is rendered by the loop above.
        examples.append({"question": "", "answer": ""})
        st.session_state["examples"] = examples
        st.rerun()
201
+
202
+ ################################################################################
203
+ # Save Domain Data
204
+ ################################################################################
205
+
206
# Drop blank entries (empty strings) from both lists.
perspectives = [entry for entry in perspectives if entry]
topics = [entry for entry in topics if entry]

# Assemble the seed document; the two derived fields are computed by the
# helpers imported from `utils`.
domain_data = dict(
    domain=domain,
    perspectives=perspectives,
    topics=topics,
    examples=examples,
    domain_expert_prompt=domain_expert_prompt,
    application_instruction=create_application_instruction(domain, examples),
    seed_terms=create_seed_terms(topics, perspectives),
)

# Write the seed document to SEED_DATA_PATH as indented JSON.
with open(SEED_DATA_PATH, "w") as seed_file:
    json.dump(domain_data, seed_file, indent=2)

# Show the same JSON in the raw seed tab.
with tab_raw_seed:
    raw_seed = json.dumps(domain_data, indent=2)
    st.code(raw_seed, language="json", line_numbers=True)
224
+
225
+ ################################################################################
226
+ # Setup Dataset on the Hub
227
+ ################################################################################
228
+
229
st.divider()


# The button is always rendered; whether a click is acted on depends on the
# required fields being filled in.
push_clicked = st.button("🤗 Push Dataset Seed")
required_fields_filled = all(
    (
        domain,
        domain_expert_prompt,
        perspectives,
        topics,
        examples,
    )
)

if not required_fields_filled:
    # Only shown while something is actually missing, so the message does not
    # linger once the form is complete.
    st.info(
        "Please fill in all the required domain fields to push the dataset seed to the Hub"
    )
elif push_clicked:
    project_name = st.session_state.get("project_name")
    hub_username = st.session_state.get("hub_username")
    hub_token = st.session_state.get("hub_token")

    if not all((project_name, hub_username, hub_token)):
        st.error(
            "Please create a dataset repo on the Hub before pushing the dataset seed"
        )
        st.stop()

    # Push the seed file written to SEED_DATA_PATH, together with the pipeline
    # at PIPELINE_PATH, to the project's dataset repo.
    push_dataset_to_hub(
        domain_seed_data_path=SEED_DATA_PATH,
        project_name=project_name,
        domain=domain,
        hub_username=hub_username,
        hub_token=hub_token,
        pipeline_path=PIPELINE_PATH,
    )

    st.success(
        f"Dataset seed created and pushed to the Hub. Check it out [here](https://huggingface.co/datasets/{hub_username}/{project_name})"
    )

    st.write("You can now move on to running your distilabel pipeline.")

    st.page_link(
        page="pages/3_🌱 Generate Dataset.py",
        label="Generate Dataset",
        icon="🌱",
    )
pages/3_🌱 Generate Dataset.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from defaults import ARGILLA_URL
4
+ from hub import push_pipeline_params
5
+ from utils import project_sidebar
6
+
7
# Set the browser tab title/icon, then draw the shared project sidebar.
st.set_page_config(page_title="Domain Data Grower", page_icon="🧑‍🌾")

project_sidebar()

################################################################################
# HEADER
################################################################################

st.header("🧑‍🌾 Domain Data Grower")
st.divider()
st.subheader("Step 3. Run the pipeline to generate synthetic data")
st.write("Define the distilabel pipeline for generating the dataset.")

# Project details stored in the session; each is None when its key is absent.
hub_username, project_name, hub_token = (
    st.session_state.get(key)
    for key in ("hub_username", "project_name", "hub_token")
)
26
+
27
+ ###############################################################
28
+ # CONFIGURATION
29
+ ###############################################################
30
+
31
st.divider()

st.markdown("## 🧰 Pipeline Configuration")

st.write(
    "Now we need to define the configuration for the pipeline that will generate the synthetic data."
)
st.write(
    "⚠️ Model and parameter choices significantly affect the quality of the generated data. "
    "We recommend that you start with generating a few samples and review the data. Then scale up from there. "
    "You can run the pipeline multiple times with different configurations and append it to the same Argilla dataset."
)


st.markdown("#### 🤖 Inference configuration")

st.write(
    "Add the url of the Huggingface inference API or endpoint that your pipeline should use. You can find compatible models here:"
)

# Collapsible list of suggested endpoints, ordered from most to least demanding.
with st.expander("🤗 Recommended Models"):
    st.write("All inference endpoint compatible models can be found via the link below")
    st.link_button(
        "🤗 Inference compatible models on the hub",
        "https://huggingface.co/models?pipeline_tag=text-generation&other=endpoints_compatible&sort=trending",
    )

    # The 70B recommendation must point at the 70B checkpoint; the URL shown
    # here was previously a copy of the 8B one below.
    st.write("🔋Projects with sufficient resources could take advantage of LLama3 70b")
    st.code(
        "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-70B-Instruct"
    )

    st.write("🪫Projects with less resources could take advantage of LLama 3 8b")
    st.code(
        "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
    )

    st.write("🍃Projects with even less resources could use Phi-3-mini-4k-instruct")
    st.code(
        "https://api-inference.huggingface.co/models/microsoft/Phi-3-mini-4k-instruct"
    )

    st.write("Note Huggingface Pro gives access to more compute resources")
    st.link_button(
        "🤗 Huggingface Pro",
        "https://huggingface.co/pricing",
    )
77
+
78
+
79
# Both model URL inputs start from the same default endpoint.
_default_base_url = (
    "https://api-inference.huggingface.co/models/microsoft/Phi-3-mini-4k-instruct"
)

self_instruct_base_url = st.text_input(
    "Model base URL for instruction generation",
    _default_base_url,
)
domain_expert_base_url = st.text_input(
    "Model base URL for domain expert response",
    _default_base_url,
)

st.divider()
st.markdown("#### 🧮 Parameters configuration")

# Generation counts: range 1-10, initial value 2.
self_intruct_num_generations = st.slider(
    "Number of generations for self-instruction",
    min_value=1,
    max_value=10,
    value=2,
)
domain_expert_num_generations = st.slider(
    "Number of generations for domain expert response",
    min_value=1,
    max_value=10,
    value=2,
)

# Temperatures: range 0.1-1.0, initial value 0.9.
self_instruct_temperature = st.slider(
    "Temperature for self-instruction",
    min_value=0.1,
    max_value=1.0,
    value=0.9,
)
domain_expert_temperature = st.slider(
    "Temperature for domain expert",
    min_value=0.1,
    max_value=1.0,
    value=0.9,
)

st.divider()
st.markdown("#### 🔬 Argilla API details to push the generated dataset")
# The dataset name defaults to the project name read from the session.
argilla_url = st.text_input(label="Argilla API URL", value=ARGILLA_URL)
argilla_api_key = st.text_input(label="Argilla API Key", value="owner.apikey")
argilla_dataset_name = st.text_input(label="Argilla Dataset Name", value=project_name)
st.divider()
106
+
107
+ ###############################################################
108
+ # LOCAL
109
+ ###############################################################
110
+
111
st.markdown("## Run the pipeline")

st.markdown(
    "Once you've defined the pipeline configuration above, you can run the pipeline from your local machine."
)


# The save button is only rendered once every setting below has a value;
# otherwise the user is asked to complete the form.
if all(
    [
        argilla_api_key,
        argilla_url,
        self_instruct_base_url,
        domain_expert_base_url,
        self_intruct_num_generations,
        domain_expert_num_generations,
        self_instruct_temperature,
        domain_expert_temperature,
        hub_username,
        project_name,
        hub_token,
        argilla_dataset_name,
    ]
) and st.button("💾 Save Pipeline Config"):
    with st.spinner("Pushing pipeline to the Hub..."):
        # The key spellings (including "self_intruct_num_generations") are
        # kept exactly as they were, since whatever consumes these params is
        # not visible here.
        push_pipeline_params(
            pipeline_params={
                "argilla_api_key": argilla_api_key,
                "argilla_api_url": argilla_url,
                "argilla_dataset_name": argilla_dataset_name,
                "self_instruct_base_url": self_instruct_base_url,
                "domain_expert_base_url": domain_expert_base_url,
                "self_instruct_temperature": self_instruct_temperature,
                "domain_expert_temperature": domain_expert_temperature,
                "self_intruct_num_generations": self_intruct_num_generations,
                "domain_expert_num_generations": domain_expert_num_generations,
            },
            hub_username=hub_username,
            hub_token=hub_token,
            project_name=project_name,
        )

    st.success(
        f"Pipeline configuration pushed to the dataset repo {hub_username}/{project_name} on the Hub."
    )

    st.markdown(
        "To run the pipeline locally, you need to have the `distilabel` library installed. You can install it using the following command:"
    )

    # Shell snippets are plain strings (no placeholders) and are highlighted
    # as bash, like the final run command below.
    st.code(
        "\n"
        "\n"
        "# Install the distilabel library\n"
        "pip install distilabel\n",
        language="bash",
    )

    st.markdown("Next, you'll need to clone your dataset repo and run the pipeline:")

    st.code(
        "\n"
        "git clone https://github.com/huggingface/data-is-better-together\n"
        "cd data-is-better-together/domain-specific-datasets/pipelines\n"
        "pip install -r requirements.txt\n",
        language="bash",
    )

    st.markdown("Finally, you can run the pipeline using the following command:")

    st.code(
        "\n"
        "huggingface-cli login\n"
        f"python domain_expert_pipeline.py {hub_username}/{project_name}",
        language="bash",
    )
    st.markdown(
        "👩‍🚀 If you want to customise the pipeline take a look in `pipeline.py` and the [distilabel docs](https://distilabel.argilla.io/)"
    )

    st.markdown(
        "🚀 Once you've run the pipeline your records will be available in the Argilla space"
    )

    st.link_button("🔗 Argilla Space", argilla_url)

    st.markdown("Once you've reviewed the data, you can publish it on the next page:")

    st.page_link(
        page="pages/4_🔍 Review Generated Data.py",
        label="Review Generated Data",
        icon="🔍",
    )

else:
    st.info("Please fill all the required fields.")
pages/4_🔍 Review Generated Data.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from defaults import PROJECT_NAME, ARGILLA_URL, DATASET_REPO_ID
4
+ from utils import project_sidebar
5
+ from hub import push_argilla_dataset_to_hub
6
+
7
# Set the browser tab title/icon, then draw the shared project sidebar.
st.set_page_config(page_title="Domain Data Grower", page_icon="🧑‍🌾")

project_sidebar()

################################################################################
# HEADER
################################################################################

st.header("🧑‍🌾 Domain Data Grower")
st.divider()

# Introductory sentence for the page.
st.write(
    "Once you have reviewed the synthetic data in Argilla, you can publish the\n"
    "generated dataset to the Hub."
)
25
+
26
+
27
+ ################################################################################
28
+ # Configuration
29
+ ################################################################################
30
+
31
st.divider()
st.write("🔬 Argilla API details to push the generated dataset")
# Connection details for the Argilla dataset and the target Hub repo,
# pre-filled from the project defaults.
argilla_url = st.text_input("Argilla API URL", ARGILLA_URL)
argilla_api_key = st.text_input("Argilla API Key", "owner.apikey")
argilla_dataset_name = st.text_input("Argilla Dataset Name", PROJECT_NAME)
dataset_repo_id = st.text_input("Dataset Repo ID", DATASET_REPO_ID)
st.divider()

if st.button("🚀 Publish the generated dataset"):
    # Refuse to publish with a blank field rather than passing empty values on
    # to the publishing helper.
    if not all((argilla_url, argilla_api_key, argilla_dataset_name, dataset_repo_id)):
        st.error("Please fill all the required fields.")
    else:
        with st.spinner("Publishing the generated dataset..."):
            # NOTE(review): the workspace is hard-coded to "admin"; confirm
            # that this matches the Argilla deployment in use.
            push_argilla_dataset_to_hub(
                name=argilla_dataset_name,
                repo_id=dataset_repo_id,
                url=argilla_url,
                api_key=argilla_api_key,
                workspace="admin",
            )
        st.success("The generated dataset has been published to the Hub.")