import json

import gradio as gr
from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps.tasks.argillalabeller import ArgillaLabeller

llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
    tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
    generation_kwargs={"max_new_tokens": 1000 * 4},
)
task = ArgillaLabeller(llm=llm)
task.load()
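
# The task instance above is shared across requests: the dataset's fields,
# question and optional example records are supplied per call as runtime
# parameters, and the records themselves are fed in as [{"record": ...}]
# inputs (see process_records_gradio below).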


def load_examples():
    with open("examples.json", "r") as f:
        return json.load(f)


# Create Gradio examples
examples = load_examples()
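
# Each entry in examples.json is expected to line up with the four JSON inputs
# of the interface below (records, example records, fields, question). A
# hypothetical entry could look like this (the field and question names are
# assumptions, not taken from the bundled examples file):
# [
#     "[{\"fields\": {\"text\": \"I love this product\"}}]",
#     "[]",
#     "[{\"name\": \"text\", \"title\": \"Text\"}]",
#     "{\"name\": \"label\", \"title\": \"Label\"}"
# ]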


def process_fields(fields):
    if isinstance(fields, str):
        fields = json.loads(fields)
    if isinstance(fields, dict):
        fields = [fields]
    return [field if isinstance(field, dict) else json.loads(field) for field in fields]
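
# process_fields normalises whatever the UI sends into a list of dicts, e.g.:
#   process_fields('{"name": "text"}')   -> [{"name": "text"}]
#   process_fields({"name": "text"})     -> [{"name": "text"}]
#   process_fields(['{"name": "text"}']) -> [{"name": "text"}]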


def process_records_gradio(records, example_records=None, fields=None, question=None):
    try:
        # Convert string inputs to dictionaries
        records = json.loads(records)
        example_records = json.loads(example_records) if example_records else None
        fields = process_fields(fields) if fields else None
        question = json.loads(question) if question else None

        if not fields and not question:
            return "Error: Either fields or question must be provided"

        runtime_parameters = {"fields": fields, "question": question}
        if example_records:
            runtime_parameters["example_records"] = example_records
        task.set_runtime_parameters(runtime_parameters)

        results = []
        output = task.process(inputs=[{"record": record} for record in records])
        for _ in range(len(records)):
            entry = next(output)[0]
            if entry["suggestions"]:
                results.append(entry["suggestions"])

        return json.dumps({"results": results}, indent=2)
    except Exception as e:
        raise Exception(f"Error: {str(e)}") from e
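
# A minimal local smoke test (the record, field and question shapes below are
# hypothetical placeholders, not the schemas of a real Argilla dataset):
#
#   sample_records = json.dumps([{"fields": {"text": "I love this product"}}])
#   sample_fields = json.dumps([{"name": "text", "title": "Text"}])
#   sample_question = json.dumps({"name": "label", "title": "Label"})
#   print(process_records_gradio(sample_records, None, sample_fields, sample_question))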


description = """
An example workflow for building the JSON payload from an Argilla dataset and
sending it to this Space with the Gradio Python client.

```python
import json
import os

import argilla as rg
from gradio_client import Client

# Initialize Argilla client
client = rg.Argilla(
    api_key=os.environ["ARGILLA_API_KEY"], api_url=os.environ["ARGILLA_API_URL"]
)

# Load the dataset
dataset = client.datasets(name="my_dataset", workspace="my_workspace")

# Prepare example data
example_field = dataset.settings.fields["my_input_field"].serialize()
example_question = dataset.settings.questions["my_question_to_predict"].serialize()

payload = {
    "records": [next(dataset.records()).to_dict()],
    "fields": [example_field],
    "question": example_question,
}

# Use the Gradio client to send the payload to the Space
space_client = Client("davidberenstein1957/distilabel-argilla-labeller")
result = space_client.predict(
    records=json.dumps(payload["records"]),
    # This minimal payload has no example records, so fall back to an empty list
    example_records=json.dumps(payload.get("example_records", [])),
    fields=json.dumps(payload["fields"]),
    question=json.dumps(payload["question"]),
    api_name="/predict",
)
```
"""


interface = gr.Interface(
    fn=process_records_gradio,
    inputs=[
        gr.Code(label="Records (JSON)", language="json", lines=5),
        gr.Code(label="Example Records (JSON, optional)", language="json", lines=5),
        gr.Code(label="Fields (JSON, optional)", language="json"),
        gr.Code(label="Question (JSON, optional)", language="json"),
    ],
    examples=examples,
    outputs=gr.Code(label="Suggestions", language="json", lines=10),
    title="Distilabel - ArgillaLabeller - Record Processing Interface",
    description=description,
)

if __name__ == "__main__":
    interface.launch()