"""Run a SelfInstruct distilabel pipeline configured from a Hub dataset repo.

The dataset repository is expected to contain two JSON files:

* ``pipeline_params.json`` -- endpoint URL and generation settings.
* ``seed_data.json`` -- domain name, application description and seed terms.
"""

import json
import os
import sys

from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import (
    LoadDataFromDicts,
    TextGenerationToArgilla,
    ExpandColumns,
)
from distilabel.steps.tasks import SelfInstruct
from huggingface_hub import hf_hub_download


def _load_hub_json(repo_id: str, filename: str) -> dict:
    """Download ``filename`` from the dataset repo ``repo_id`` and parse it as JSON.

    The callers treat the result as a mapping (they call ``.get`` on it), so
    the file is expected to hold a JSON object.
    """
    local_path = hf_hub_download(
        repo_id=repo_id, filename=filename, repo_type="dataset"
    )
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(local_path, "r", encoding="utf-8") as f:
        return json.load(f)


def run(repo_id: str) -> None:
    """Build and run the SelfInstruct pipeline described by ``repo_id``.

    Args:
        repo_id: Hub dataset repository holding ``pipeline_params.json`` and
            ``seed_data.json``.

    Raises:
        ValueError: If ``seed_data.json`` has no ``seed_terms`` entry.
    """
    # The token is read from the environment so it never lives in the repo.
    hub_token = os.environ.get("HF_TOKEN")

    # Only the self-instruct settings are consumed by this pipeline; other
    # keys present in the params file are ignored here.
    params = _load_hub_json(repo_id, "pipeline_params.json")
    self_instruct_base_url = params.get("self_instruct_base_url")
    self_instruct_num_generations = params.get("self_instruct_num_generations", 2)
    self_instruct_temperature = params.get("self_instruct_temperature", 0.9)
    self_instruct_max_new_tokens = params.get("self_instruct_max_new_tokens", 1024)

    seed_data = _load_hub_json(repo_id, "seed_data.json")
    application_instruction = seed_data.get("application_instruction")
    domain_name = seed_data.get("domain")
    terms = seed_data.get("seed_terms")
    if terms is None:
        # Fail with a message naming the missing key instead of the opaque
        # "'NoneType' object is not iterable" the comprehension would raise.
        raise ValueError(
            f"seed_data.json in {repo_id!r} does not define 'seed_terms'"
        )

    with Pipeline(domain_name) as pipeline:
        # One input row per seed term.
        load_data = LoadDataFromDicts(
            name="load_data",
            batch_size=64,
            data=[{"input": term} for term in terms],
        )

        self_instruct = SelfInstruct(
            name="self_instruct",
            num_instructions=self_instruct_num_generations,
            input_batch_size=8,
            llm=InferenceEndpointsLLM(
                api_key=hub_token,
                base_url=self_instruct_base_url,
            ),
            application_description=application_instruction,
        )

        # Connect up the pipeline
        load_data.connect(self_instruct)

    # Run the pipeline; generation settings are passed as runtime parameters
    # keyed by the step name given above.
    pipeline.run(
        use_cache=False,
        parameters={
            "self_instruct": {
                "llm": {
                    "generation_kwargs": {
                        "max_new_tokens": self_instruct_max_new_tokens,
                        "temperature": self_instruct_temperature,
                    },
                }
            },
        },
    )


if __name__ == "__main__":
    if len(sys.argv) < 2:
        # Exit with a usage hint rather than an IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} <dataset_repo_id>")
    run(sys.argv[1])