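"""Distilabel pipeline for bootstrapping domain-specific instruction data.

Reads pipeline parameters and seed data from a Hugging Face dataset repo,
then runs a SelfInstruct task over the seed terms against an Inference
Endpoint to generate candidate instructions for the domain.
"""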
import json
import os
import sys

from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import (
    ExpandColumns,
    LoadDataFromDicts,
    TextGenerationToArgilla,
)
from distilabel.steps.tasks import SelfInstruct
from huggingface_hub import hf_hub_download


def run(repo_id):
    # Get the Hugging Face API token from the environment
    hub_token = os.environ.get("HF_TOKEN")

    # Download the generation parameters stored alongside the dataset
    with open(
        hf_hub_download(
            repo_id=repo_id, filename="pipeline_params.json", repo_type="dataset"
        ),
        "r",
    ) as f:
        params = json.load(f)

    self_instruct_base_url = params.get("self_instruct_base_url")
    self_instruct_num_generations = params.get("self_instruct_num_generations", 2)
    domain_expert_num_generations = params.get("domain_expert_num_generations", 2)
    self_instruct_temperature = params.get("self_instruct_temperature", 0.9)
    domain_expert_temperature = params.get("domain_expert_temperature", 0.9)
    self_instruct_max_new_tokens = params.get("self_instruct_max_new_tokens", 1024)
    domain_expert_max_new_tokens = params.get("domain_expert_max_new_tokens", 1024)
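
    # For reference, pipeline_params.json is expected to look roughly like
    # this (values are illustrative; every key except self_instruct_base_url
    # falls back to the defaults above when omitted):
    #
    # {
    #     "self_instruct_base_url": "https://<endpoint>.endpoints.huggingface.cloud",
    #     "self_instruct_num_generations": 2,
    #     "domain_expert_num_generations": 2,
    #     "self_instruct_temperature": 0.9,
    #     "domain_expert_temperature": 0.9,
    #     "self_instruct_max_new_tokens": 1024,
    #     "domain_expert_max_new_tokens": 1024
    # }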

    # Download the seed data that defines the domain and its seed terms
    with open(
        hf_hub_download(
            repo_id=repo_id, filename="seed_data.json", repo_type="dataset"
        ),
        "r",
    ) as f:
        seed_data = json.load(f)

    application_instruction = seed_data.get("application_instruction")
    domain_expert_prompt = seed_data.get("domain_expert_prompt")
    domain_name = seed_data.get("domain")
    terms = seed_data.get("seed_terms")
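
    # seed_data.json is expected to carry roughly this shape (the values
    # below are hypothetical placeholders, not taken from any real dataset):
    #
    # {
    #     "domain": "farming",
    #     "application_instruction": "An assistant that answers questions about ...",
    #     "domain_expert_prompt": "You are an expert in ...",
    #     "seed_terms": ["crop rotation", "soil health", "..."]
    # }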

    with Pipeline(domain_name) as pipeline:
        # Feed each seed term into the pipeline as a separate input row
        load_data = LoadDataFromDicts(
            name="load_data",
            batch_size=64,
            data=[{"input": term} for term in terms],
        )
        # Generate candidate instructions for each seed term
        self_instruct = SelfInstruct(
            name="self_instruct",
            num_instructions=self_instruct_num_generations,
            input_batch_size=8,
            llm=InferenceEndpointsLLM(
                api_key=hub_token,
                base_url=self_instruct_base_url,
            ),
            application_description=application_instruction,
        )
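        # Note (assumption based on distilabel 1.x): SelfInstruct writes its
        # generations to a list column named "instructions"; ExpandColumns
        # (imported above) is the usual way to flatten that list into rows,
        # though it is not wired into this pipeline as written.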

        # Connect the steps so seed terms flow from load_data into self_instruct
        load_data.connect(self_instruct)

    # Run the pipeline, overriding the LLM generation kwargs at runtime
    pipeline.run(
        use_cache=False,
        parameters={
            "self_instruct": {
                "llm": {
                    "generation_kwargs": {
                        "max_new_tokens": self_instruct_max_new_tokens,
                        "temperature": self_instruct_temperature,
                    },
                }
            },
        },
    )
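
# Usage sketch (assumes HF_TOKEN is set and the dataset repo contains
# pipeline_params.json and seed_data.json as described above):
#
#   HF_TOKEN=hf_... python <this_script>.py <user>/<dataset-repo>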

if __name__ == "__main__":
    run(sys.argv[1])