# disticleaner / pipeline.py
# Author: Ben Burtenshaw -- "first commit" (39e6ae5), 2.67 kB
import os
import json
import sys
from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import (
LoadDataFromDicts,
TextGenerationToArgilla,
ExpandColumns,
)
from distilabel.steps.tasks import SelfInstruct
from huggingface_hub import hf_hub_download
def run(repo_id: str) -> None:
    """Run the self-instruct pipeline configured by a Hub dataset repository.

    Downloads ``pipeline_params.json`` and ``seed_data.json`` from the dataset
    repository ``repo_id``, builds a pipeline that feeds every seed term into
    a ``SelfInstruct`` task backed by an ``InferenceEndpointsLLM``, and runs
    it with caching disabled.

    Keys read from ``pipeline_params.json``:
        * ``self_instruct_base_url``
        * ``self_instruct_num_generations`` (default ``2``)
        * ``self_instruct_temperature`` (default ``0.9``)
        * ``self_instruct_max_new_tokens`` (default ``1024``)

    Keys read from ``seed_data.json``:
        * ``application_instruction``
        * ``domain`` (used as the pipeline name)
        * ``seed_terms`` (one input row is created per term)

    Any other keys in either file are ignored.

    Args:
        repo_id: Identifier of the Hub dataset repository holding both files.

    Raises:
        ValueError: If ``seed_data.json`` has no ``seed_terms`` entry.
    """
    # Token for the inference endpoint; None when HF_TOKEN is not set.
    hub_token = os.environ.get("HF_TOKEN")

    params = _load_hub_json(repo_id, "pipeline_params.json")
    self_instruct_base_url = params.get("self_instruct_base_url")
    self_instruct_num_generations = params.get("self_instruct_num_generations", 2)
    self_instruct_temperature = params.get("self_instruct_temperature", 0.9)
    self_instruct_max_new_tokens = params.get("self_instruct_max_new_tokens", 1024)

    seed_data = _load_hub_json(repo_id, "seed_data.json")
    application_instruction = seed_data.get("application_instruction")
    domain_name = seed_data.get("domain")
    terms = seed_data.get("seed_terms")
    if terms is None:
        # Fail with a clear message instead of the opaque
        # "'NoneType' object is not iterable" raised further down.
        raise ValueError(
            f"seed_data.json in {repo_id!r} does not define 'seed_terms'"
        )

    with Pipeline(domain_name) as pipeline:
        load_data = LoadDataFromDicts(
            name="load_data",
            batch_size=64,
            data=[{"input": term} for term in terms],
        )

        self_instruct = SelfInstruct(
            name="self_instruct",
            num_instructions=self_instruct_num_generations,
            input_batch_size=8,
            llm=InferenceEndpointsLLM(
                api_key=hub_token,
                base_url=self_instruct_base_url,
            ),
            application_description=application_instruction,
        )

        # Connect up the pipeline
        load_data.connect(self_instruct)

    # Run the pipeline; generation settings are supplied as runtime
    # parameters for the "self_instruct" step.
    pipeline.run(
        use_cache=False,
        parameters={
            "self_instruct": {
                "llm": {
                    "generation_kwargs": {
                        "max_new_tokens": self_instruct_max_new_tokens,
                        "temperature": self_instruct_temperature,
                    },
                }
            },
        },
    )


def _load_hub_json(repo_id, filename):
    """Download ``filename`` from dataset repo ``repo_id`` and parse it as JSON."""
    local_path = hf_hub_download(
        repo_id=repo_id, filename=filename, repo_type="dataset"
    )
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # locale encoding.
    with open(local_path, "r", encoding="utf-8") as f:
        return json.load(f)
if __name__ == "__main__":
    # The first positional argument is the Hub dataset repo id passed to
    # run(). Exit with a usage message instead of an IndexError traceback
    # when it is missing; any extra arguments are ignored, as before.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <repo_id>")
    run(sys.argv[1])