import os
import sys
import time

import requests
from datasets import load_dataset, concatenate_datasets

import argilla as rg
from argilla.listeners import listener

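# Configuration: HF_TOKEN and HUB_DATASET_NAME are read from the environment;
# the remaining values are fixed for this project.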
HF_TOKEN = os.environ.get("HF_TOKEN")
SOURCE_DATASET = "LEL-A/translated_german_alpaca"
RG_DATASET_NAME = "translated-german-alpaca"
HUB_DATASET_NAME = os.environ.get("HUB_DATASET_NAME", f"{SOURCE_DATASET}_validation")
LABELS = ["BAD INSTRUCTION", "INAPPROPRIATE", "ALL GOOD", "NOT SURE", "WRONG LANGUAGE"]


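# The listener polls the Argilla dataset every `execution_interval_in_seconds`
# (here, every 20 minutes) and receives the records that match the query,
# i.e. records an annotator has marked as validated.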
@listener(
    dataset=RG_DATASET_NAME,
    query="status:Validated",
    execution_interval_in_seconds=1200,
)
def save_validated_to_hub(records, ctx):
    if len(records) > 0:
        ds = rg.DatasetForTextClassification(records=records).to_datasets()
        if HF_TOKEN:
            print("Pushing the dataset")
            print(ds)
            ds.push_to_hub(HUB_DATASET_NAME, token=HF_TOKEN)
        else:
            print("Set HF_TOKEN and HUB_DATASET_NAME to sync your dataset!")
    else:
        print("No records found")


class LoadDatasets:
    def __init__(self, api_key, workspace="team"):
        # Connect the Argilla client to the running server
        rg.init(api_key=api_key, workspace=workspace)

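    # Pull the source dataset, merge any previously validated records from the
    # Hub, and load everything into Argilla for annotation.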
    @staticmethod
    def load_somos():
        # Try to resume from records that were already validated and pushed
        try:
            print(f"Trying to sync with {HUB_DATASET_NAME}")
            old_ds = load_dataset(HUB_DATASET_NAME, split="train")
        except Exception as e:
            print(f"Not possible to sync with {HUB_DATASET_NAME}")
            print(e)
            old_ds = None

        print(f"Loading dataset: {SOURCE_DATASET}")
        dataset = load_dataset(SOURCE_DATASET, split="train")

        if old_ds:
            print("Concatenating datasets")
            dataset = concatenate_datasets([dataset, old_ds])
            print("Concatenated dataset is:")
            print(dataset)

        # Drop the "metrics" column before converting to Argilla records
        dataset = dataset.remove_columns("metrics")
        records = rg.DatasetForTextClassification.from_datasets(dataset)

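        # Configure the label schema before logging any records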
        settings = rg.TextClassificationSettings(label_schema=LABELS)

        print(f"Configuring dataset: {RG_DATASET_NAME}")
        rg.configure_dataset(name=RG_DATASET_NAME, settings=settings, workspace="team")

        print(f"Logging dataset: {RG_DATASET_NAME}")
        rg.log(
            records,
            name=RG_DATASET_NAME,
            tags={"description": "Alpaca dataset to clean up"},
            batch_size=200,
        )

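        # Start the background listener so validated records sync to the Hub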
        save_validated_to_hub.start()


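# Entry point: expects the Argilla API key as argv[1] and a load flag as
# argv[2]; pass "none" to skip loading the dataset.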
if __name__ == "__main__":
    API_KEY = sys.argv[1]
    LOAD_DATASETS = sys.argv[2]

    if LOAD_DATASETS.lower() == "none":
        print("No datasets being loaded")
    else:
        # Wait until the Argilla server responds before loading data
        while True:
            try:
                response = requests.get("http://0.0.0.0:6900/")
                if response.status_code == 200:
                    ld = LoadDatasets(API_KEY)
                    ld.load_somos()
                    break
            except requests.exceptions.ConnectionError:
                pass
            except Exception as e:
                print(e)
                time.sleep(10)

            time.sleep(5)

        # Keep the process alive so the listener keeps running
        while True:
            time.sleep(60)