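"""Load the alpaca-bangla dataset into Argilla for human validation.

Waits for the local Argilla server to come up, logs the source records as a
text-classification task, and runs a background listener that periodically
pushes validated records to the Hub so annotations survive Space restarts.
"""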
import sys
import time
import os

import requests
from datasets import load_dataset, concatenate_datasets
import argilla as rg
from argilla.listeners import listener
### Configuration section ###
# Needed for pushing the validated data to HUB_DATASET_NAME
HF_TOKEN = os.environ.get("HF_TOKEN")
# The source dataset with the translated Alpaca examples
SOURCE_DATASET = "nihalbaig/alpaca-bangla"
# The name of the dataset in Argilla
RG_DATASET_NAME = "alpaca-bangla"
# The Hub dataset that receives the validations every 20 min, keeping the dataset synced
HUB_DATASET_NAME = os.environ.get("HUB_DATASET_NAME", f"{SOURCE_DATASET}_validation")
# The labels for the task (they can be extended if needed)
LABELS = ["BAD INSTRUCTION", "INAPPROPRIATE", "ALL GOOD", "NOT SURE", "WRONG LANGUAGE"]
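
# The function below runs as an Argilla listener: on a fixed interval it
# queries annotated records and mirrors them to the Hub dataset. The exact
# query string and interval used in the decorator are assumed from the
# comments above (validated records, every 20 minutes).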
@listener(
    dataset=RG_DATASET_NAME,
    query="status:Validated",  # assumed query: only records marked as validated
    execution_interval_in_seconds=20 * 60,  # every 20 minutes, per the comment above
)
def save_validated_to_hub(records, ctx):
    if len(records) > 0:
        # Convert the Argilla records into a datasets.Dataset before pushing
        ds = rg.DatasetForTextClassification(records=records).to_datasets()
        if HF_TOKEN:
            print("Pushing the dataset")
            print(ds)
            ds.push_to_hub(HUB_DATASET_NAME, token=HF_TOKEN)
        else:
            print("SET HF_TOKEN and HUB_DATASET_NAME TO SYNC YOUR DATASET!!!")
    else:
        print("NO RECORDS found")

class LoadDatasets:
    def __init__(self, api_key, workspace="team"):
        # Connect to the local Argilla instance
        rg.init(api_key=api_key, workspace=workspace)
    @staticmethod
    def load_somos():
        # Read the previously validated dataset from the Hub, if it exists
        try:
            print(f"Trying to sync with {HUB_DATASET_NAME}")
            old_ds = load_dataset(HUB_DATASET_NAME, split="train")
        except Exception as e:
            print(f"Not possible to sync with {HUB_DATASET_NAME}")
            print(e)
            old_ds = None
print(f"Loading dataset: {SOURCE_DATASET}") | |
dataset = load_dataset(SOURCE_DATASET, split="train") | |
if old_ds: | |
print("Concatenating datasets") | |
dataset = concatenate_datasets([dataset, old_ds]) | |
print("Concatenated dataset is:") | |
print(dataset) | |
dataset = dataset.remove_columns("metrics") | |
records = rg.DatasetForTextClassification.from_datasets(dataset) | |
settings = rg.TextClassificationSettings( | |
label_schema=LABELS | |
) | |
print(f"Configuring dataset: {RG_DATASET_NAME}") | |
rg.configure_dataset(name=RG_DATASET_NAME, settings=settings, workspace="team") | |
        # Log the records into Argilla
        print(f"Logging dataset: {RG_DATASET_NAME}")
        rg.log(
            records,
            name=RG_DATASET_NAME,
            tags={"description": "Alpaca dataset to clean up"},
            batch_size=200,
        )

        # Start the listener; it runs in a background thread, so the main
        # process must stay alive (see the final loop below)
        save_validated_to_hub.start()
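

# Expected invocation (e.g. from the Space's startup script):
#   python load.py <argilla-api-key> load
# Pass "none" as the second argument to skip loading and just keep the
# process alive. The script name "load.py" is an assumption; any name works.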
if __name__ == "__main__":
    API_KEY = sys.argv[1]
    LOAD_DATASETS = sys.argv[2]

    if LOAD_DATASETS.lower() == "none":
        print("No datasets being loaded")
    else:
        # Poll until the local Argilla server responds, then load the data
        while True:
            try:
                response = requests.get("http://0.0.0.0:6900/")
                if response.status_code == 200:
                    ld = LoadDatasets(API_KEY)
                    ld.load_somos()
                    break
            except requests.exceptions.ConnectionError:
                pass
            except Exception as e:
                print(e)
                time.sleep(10)
            time.sleep(5)

    # Keep the process alive so the background listener keeps syncing
    while True:
        time.sleep(60)