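"""Load the somos-clean-alpaca-es dataset into a local Argilla instance and keep
validated records in sync with a dataset on the Hugging Face Hub.

Usage (inferred from the __main__ block below): pass the Argilla API key as the
first argument; pass "none" as the second argument to skip loading the dataset.
"""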
import sys
import time
import os

import argilla as rg
import requests
from datasets import load_dataset, concatenate_datasets

from argilla.listeners import listener

# Credentials and target repo used to sync validated records back to the Hub
HF_TOKEN = os.environ.get("HF_TOKEN")
HUB_DATASET_NAME = os.environ.get("HUB_DATASET_NAME")

# Background listener: every `execution_interval_in_seconds`, fetch the records in
# `dataset` that match `query` and pass them to `save_validated_to_hub`
@listener(
    dataset="somos-alpaca-es",
    query="status:Validated",  # https://docs.argilla.io/en/latest/guides/features/queries.html
    execution_interval_in_seconds=1200,  # how often `save_validated_to_hub` is executed
)
def save_validated_to_hub(records, ctx):
    # Convert the validated records to a `datasets.Dataset` and push them to the Hub
    if len(records) > 0:
        ds = rg.DatasetForTextClassification(records=records).to_datasets()
        if HF_TOKEN:
            print("Pushing the dataset")
            print(ds)
            ds.push_to_hub(HUB_DATASET_NAME, token=HF_TOKEN)
        else:
            print("SET HF_TOKEN and HUB_DATASET_NAME TO SYNC YOUR DATASET!!!")
    else:
        print("NO RECORDS found")

class LoadDatasets:
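    """Connects to the running Argilla instance and loads the hackathon dataset."""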
    def __init__(self, api_key, workspace="team"):
        rg.init(api_key=api_key, workspace=workspace)

    @staticmethod
    def load_somos():
        # Read the previously validated dataset from the Hub, if it exists
        try:
            print(f"Trying to sync with {HUB_DATASET_NAME}")
            old_ds = load_dataset(HUB_DATASET_NAME, split="train")
        except Exception as e:
            print(f"Not possible to sync with {HUB_DATASET_NAME}")
            print(e)
            old_ds = None

        # Load the source dataset for the hackathon
        dataset = load_dataset("somosnlp/somos-clean-alpaca-es", split="train")

        if old_ds:
            print("Concatenating datasets")
            dataset = concatenate_datasets([dataset, old_ds])
            print("Concatenated dataset is:")
            print(dataset)
            
        # Drop the metrics column before converting the rows into Argilla records
        dataset = dataset.remove_columns("metrics")
        records = rg.DatasetForTextClassification.from_datasets(dataset)

        # Define the label schema and configure the Argilla dataset with it
        settings = rg.TextClassificationSettings(
            label_schema=["BAD INSTRUCTION", "BAD INPUT", "BAD OUTPUT", "INAPPROPRIATE", "BIASED", "ALL GOOD"]
        )
        rg.configure_dataset(name="somos-alpaca-es", settings=settings, workspace="team")
        
        # Log the dataset
        rg.log(
            records,
            name="somos-alpaca-es",
            tags={"description": "SomosNLP Hackathon dataset"},
            batch_size=200
        )
        
        # Start the background listener that pushes validated records to the Hub
        save_validated_to_hub.start()

if __name__ == "__main__":
    API_KEY = sys.argv[1]
    LOAD_DATASETS = sys.argv[2]

    if LOAD_DATASETS.lower() == "none":
        print("No datasets being loaded")
    else:
        # Wait until the local Argilla server is reachable, then load the dataset
        while True:
            try:
                response = requests.get("http://0.0.0.0:6900/")
                if response.status_code == 200:
                    ld = LoadDatasets(API_KEY)
                    ld.load_somos()
                    break
            except requests.exceptions.ConnectionError:
                pass
            except Exception as e:
                print(e)
                time.sleep(10)

            time.sleep(5)

    # Keep the process alive so the listener keeps running in the background
    while True:
        time.sleep(60)