argilla-direct / src /argilla_utils.py
Ben Burtenshaw
update to argilla on main repo
836f663
import argilla as rg
from datasets import load_dataset
from datasets import load_dataset
from src.dataset import (
load_split,
is_label,
is_rating,
is_int,
is_float,
get_feature_values,
get_feature_labels,
load_repo_id,
)
def _unique_mapping_key(column_name, mapping):
    """Return ``column_name`` suffixed with ``__`` until it is not a key of ``mapping``."""
    while column_name in mapping:
        column_name = f"{column_name}__"
    return column_name


def define_dataset_setting(
    dataset_name, field_columns, question_columns, metadata_columns, argilla_space_url
):
    """Build Argilla settings from column selections and create the dataset.

    Args:
        dataset_name: Name of the Argilla dataset to create.
        field_columns: Source column names; each becomes a ``TextField`` named
            ``"<column>_field"``.
        question_columns: Iterable of ``(question_type, question_name,
            column_name)`` triples. ``"Label"`` creates a ``LabelQuestion``,
            ``"Rating"`` a ``RatingQuestion``, anything else a ``TextQuestion``.
        metadata_columns: Iterable of ``(metadata_type, metadata_name,
            column_name)`` triples, or a falsy value for no metadata.
            ``"Integer"``, ``"Float"`` and ``"Term"`` are recognised.
        argilla_space_url: API URL of the Argilla server.

    Returns:
        A ``(settings, mapping)`` tuple: the serialized settings converted to a
        string, and a dict from source column key to the Argilla attribute name.
        A column used more than once appears under extra keys with one or more
        ``__`` suffixes appended.
    """
    client = rg.Argilla(api_url=argilla_space_url, api_key="owner.apikey")
    split = load_split()
    fields, questions, metadata = [], [], []
    mapping = {}

    # Add field columns
    for column_name in field_columns:
        field_column_name = f"{column_name}_field"
        fields.append(rg.TextField(name=field_column_name))
        mapping[column_name] = field_column_name

    # Add question columns
    for question_type, question_column_name, column_name in question_columns:
        if question_type == "Label":
            values = get_feature_values(split, column_name)
            titles = get_feature_labels(split, column_name)
            labels = {str(value): title for value, title in zip(values, titles)}
            questions.append(rg.LabelQuestion(name=question_column_name, labels=labels))
        elif question_type == "Rating":
            values = get_feature_values(split, column_name)
            questions.append(
                rg.RatingQuestion(name=question_column_name, values=values)
            )
        else:
            questions.append(rg.TextQuestion(name=question_column_name))
        # A dict key can only point at one target, so a column that is already
        # mapped gets an aliased key. Keep suffixing until the key is free so a
        # third use of the same column does not overwrite the second.
        mapping[_unique_mapping_key(column_name, mapping)] = question_column_name

    # Add metadata columns
    for metadata_type, metadata_name, column_name in metadata_columns or []:
        if metadata_type == "Integer":
            metadata.append(rg.IntegerMetadataProperty(name=metadata_name))
        elif metadata_type == "Float":
            metadata.append(rg.FloatMetadataProperty(name=metadata_name))
        elif metadata_type == "Term":
            values = list(map(str, get_feature_values(split, column_name)))
            metadata.append(
                rg.TermsMetadataProperty(name=metadata_name, options=values)
            )
        # NOTE: an unrecognised metadata type creates no property but is still
        # added to the mapping below (unchanged from the previous behavior).
        mapping[_unique_mapping_key(column_name, mapping)] = metadata_name

    settings = rg.Settings(fields=fields, questions=questions, metadata=metadata)
    dataset = rg.Dataset(name=dataset_name, settings=settings, client=client)
    if not dataset.exists():
        dataset.create()
    return str(settings.serialize()), mapping
def _resolve_source_column(name, columns):
    """Strip trailing ``__`` alias markers from ``name`` until it is in ``columns``."""
    while name not in columns and name.endswith("__"):
        name = name[:-2]
    return name


def add_records(argilla_dataset_name, mapping, n_records, argilla_space_url):
    """Log the first ``n_records`` rows of the source dataset to Argilla.

    Args:
        argilla_dataset_name: Name of the existing Argilla dataset.
        mapping: Dict from source column key to Argilla attribute name. Keys
            carrying a trailing ``__`` alias suffix refer to the same source
            column as the key without the suffix.
        n_records: Number of rows to take from the split.
        argilla_space_url: API URL of the Argilla server.

    Returns:
        A message with the number of rows logged and the mapping used.
    """
    client = rg.Argilla(api_url=argilla_space_url, api_key="owner.apikey")
    split = load_split()
    df = load_dataset(load_repo_id())[split].take(n_records).to_pandas()
    dataset = client.datasets(argilla_dataset_name)

    # Snapshot the real column names before alias columns are added, so aliases
    # are resolved against the source data only.
    source_columns = set(df.columns)

    # Columns feeding label questions are converted to strings.
    for question in dataset.settings.questions:
        source = next(
            (key for key, target in mapping.items() if target == question.name),
            None,
        )
        if source is None:
            continue
        column_name = _resolve_source_column(source, source_columns)
        if is_label(split, column_name):
            df[column_name] = df[column_name].apply(str)

    # Materialise each aliased key as a copy of its source column. Only the
    # trailing alias markers are removed, so a real column whose name contains
    # or ends with "__" is left intact.
    for source in mapping:
        column_name = _resolve_source_column(source, source_columns)
        if column_name != source:
            df[source] = df[column_name]

    records = df.to_dict(orient="records")
    dataset.records.log(records, mapping=mapping)
    return f"{len(df)} records added with mapping {mapping}"
def delete_dataset(argilla_dataset_name, argilla_space_url=None):
    """Delete an Argilla dataset by name.

    Args:
        argilla_dataset_name: Name of the dataset to delete.
        argilla_space_url: API URL of the Argilla server. Optional so existing
            one-argument callers keep working.
            NOTE(review): when omitted, the client presumably falls back to
            Argilla's own default/environment configuration — confirm.

    Returns:
        A message stating whether the dataset was deleted or not found.
    """
    # Previously this function referenced a ``client`` name it never defined;
    # build the client here the same way the other helpers in this module do.
    client = rg.Argilla(api_url=argilla_space_url, api_key="owner.apikey")
    dataset = client.datasets(argilla_dataset_name)
    if dataset is None:
        # NOTE(review): assumes the lookup yields None for an unknown name
        # rather than raising — confirm against the Argilla SDK.
        return f"Dataset {argilla_dataset_name} not found"
    dataset.delete()
    return f"Dataset {argilla_dataset_name} deleted"