"""Load the translated German Alpaca dataset into Argilla and sync validations to the Hub.

Usage: python <script> <ARGILLA_API_KEY> <LOAD_DATASETS>

When LOAD_DATASETS is "none" (case-insensitive) nothing is loaded. Otherwise the
script polls the local Argilla server until it answers, logs the dataset, and
starts a listener that every 1200 seconds pushes the Validated records to
HUB_DATASET_NAME on the Hugging Face Hub.
"""
import sys
import time
import os

import pandas as pd
import requests
from datasets import load_dataset, concatenate_datasets

import argilla as rg
from argilla.listeners import listener

### Configuration section ###

# needed for pushing the validated data to HUB_DATASET_NAME
HF_TOKEN = os.environ.get("HF_TOKEN")

# The source dataset to read Alpaca translated examples
SOURCE_DATASET = "LEL-A/translated_german_alpaca"

# The name of the dataset in Argilla
RG_DATASET_NAME = "translated-german-alpaca"

# The name of the Hub dataset to push the validations every 20 min and keep the dataset synced
HUB_DATASET_NAME = os.environ.get('HUB_DATASET_NAME', f"{SOURCE_DATASET}_validation")

# The labels for the task (they can be extended if needed)
LABELS = ["BAD INSTRUCTION", "INAPPROPRIATE", "ALL GOOD", "NOT SURE", "WRONG LANGUAGE"]

# Default Argilla workspace, used both for the client session and the dataset settings
RG_WORKSPACE = "team"

# Root URL polled until the local Argilla server answers with HTTP 200
ARGILLA_URL = "http://0.0.0.0:6900/"

# Seconds before a single health-check request is abandoned (avoids hanging forever)
REQUEST_TIMEOUT = 10


@listener(
    dataset=RG_DATASET_NAME,
    query="status:Validated",
    execution_interval_in_seconds=1200,  # interval to check the execution of `save_validated_to_hub`
)
def save_validated_to_hub(records, ctx):
    """Push the Validated records of RG_DATASET_NAME to HUB_DATASET_NAME.

    Called by the Argilla listener every 1200 seconds with the records matching
    ``status:Validated``. Nothing is pushed when the query returned no records
    or when HF_TOKEN is not set.

    Args:
        records: Records returned by the listener query.
        ctx: Listener context (unused).
    """
    if not records:
        print("NO RECORDS found")
        return
    if not HF_TOKEN:
        # Without a token the push would fail, so skip the conversion entirely.
        print("SET HF_TOKEN and HUB_DATASET_NAME TO SYNC YOUR DATASET!!!")
        return
    ds = rg.DatasetForTextClassification(records=records).to_datasets()
    print("Pushing the dataset")
    print(ds)
    ds.push_to_hub(HUB_DATASET_NAME, token=HF_TOKEN)


class LoadDatasets:
    """Open an Argilla session and load the source dataset into it."""

    def __init__(self, api_key, workspace=RG_WORKSPACE):
        """Initialise the Argilla client.

        Args:
            api_key: Argilla API key.
            workspace: Workspace activated for this client session.
        """
        rg.init(api_key=api_key, workspace=workspace)
        # Kept so callers can configure the dataset in the same workspace.
        self.workspace = workspace

    @staticmethod
    def load_somos(workspace=RG_WORKSPACE):
        """Build the Argilla dataset, configure its labels, log it and start the listener.

        The source dataset is stacked with the previously validated dataset from
        the Hub when that one can be loaded (best effort; failures are printed
        and ignored).

        Args:
            workspace: Workspace passed to ``rg.configure_dataset``. Pass the
                workspace used in ``rg.init`` so settings and records agree.
        """
        # Read the previously validated dataset from the Hub (it may not exist yet)
        try:
            print(f"Trying to sync with {HUB_DATASET_NAME}")
            old_ds = load_dataset(HUB_DATASET_NAME, split="train")
        except Exception as e:
            print(f"Not possible to sync with {HUB_DATASET_NAME}")
            print(e)
            old_ds = None

        print(f"Loading dataset: {SOURCE_DATASET}")
        dataset = load_dataset(SOURCE_DATASET, split="train")

        if old_ds is not None and len(old_ds) > 0:
            print("Concatenating datasets")
            # NOTE(review): source and Hub rows are simply stacked; whether records
            # sharing an id are merged depends on how rg.log handles repeated ids — confirm.
            dataset = concatenate_datasets([dataset, old_ds])
            print("Concatenated dataset is:")
            print(dataset)

        # Drop "metrics" only when present; remove_columns raises on a missing column.
        if "metrics" in dataset.column_names:
            dataset = dataset.remove_columns("metrics")
        records = rg.DatasetForTextClassification.from_datasets(dataset)

        settings = rg.TextClassificationSettings(label_schema=LABELS)

        print(f"Configuring dataset: {RG_DATASET_NAME}")
        rg.configure_dataset(name=RG_DATASET_NAME, settings=settings, workspace=workspace)

        # Log the dataset
        print(f"Logging dataset: {RG_DATASET_NAME}")
        rg.log(
            records,
            name=RG_DATASET_NAME,
            tags={"description": "Alpaca dataset to clean up"},
            batch_size=200,
        )

        # run listener
        save_validated_to_hub.start()


def _wait_and_load(api_key, poll_interval=5, error_backoff=10):
    """Poll ARGILLA_URL until it answers HTTP 200, then load the dataset once.

    Connection errors and timeouts are retried silently. Any other error,
    including failures raised by ``load_somos``, is printed and retried after an
    extra ``error_backoff`` seconds.

    Args:
        api_key: Argilla API key forwarded to ``LoadDatasets``.
        poll_interval: Seconds to sleep between polling attempts.
        error_backoff: Extra seconds to sleep after an unexpected error.
    """
    while True:
        try:
            response = requests.get(ARGILLA_URL, timeout=REQUEST_TIMEOUT)
            if response.status_code == 200:
                ld = LoadDatasets(api_key)
                ld.load_somos(workspace=ld.workspace)
                return
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # Server not reachable or not answering yet: keep polling.
            pass
        except Exception as e:
            print(e)
            time.sleep(error_backoff)
        time.sleep(poll_interval)


def main(argv=None):
    """Script entry point: ``<script> <ARGILLA_API_KEY> <LOAD_DATASETS|none>``.

    Args:
        argv: Argument list without the program name; defaults to ``sys.argv[1:]``.
    """
    argv = sys.argv[1:] if argv is None else argv
    if len(argv) < 2:
        sys.exit("Usage: python <script> <ARGILLA_API_KEY> <LOAD_DATASETS|none>")
    api_key, load_datasets = argv[0], argv[1]

    if load_datasets.lower() == "none":
        print("No datasets being loaded")
    else:
        _wait_and_load(api_key)

    # Keep the process alive; presumably needed so the listener started by
    # load_somos keeps running in the background — confirm.
    while True:
        time.sleep(60)


if __name__ == "__main__":
    main()