Spaces:
Build error
Build error
File size: 3,548 Bytes
36821d3 c9bd449 d2e7f91 c9bd449 6199610 dff7018 36821d3 dff7018 aaaeb76 6199610 36821d3 dff7018 36821d3 dff7018 36821d3 dff7018 78f9744 dff7018 36821d3 6199610 78f9744 6199610 36821d3 c0fa328 36821d3 aaaeb76 21cc5dc 778c655 21cc5dc 36821d3 dff7018 36821d3 dff7018 36821d3 21cc5dc 36821d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import json
import gradio as gr
from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps.tasks.argillalabeller import ArgillaLabeller
# The same identifier is used for the model and for its tokenizer.
_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# LLM client used by the labelling task; generations are capped at 4000 new tokens.
llm = InferenceEndpointsLLM(
    model_id=_MODEL_ID,
    tokenizer_id=_MODEL_ID,
    generation_kwargs={"max_new_tokens": 4000},
)

# Module-level task instance; process_records_gradio reuses it on every call.
task = ArgillaLabeller(llm=llm)
task.load()
def load_examples():
    """Load the example inputs stored in ``examples.json``.

    The path is relative to the current working directory. The file is
    decoded as UTF-8 explicitly so that the result does not depend on the
    platform's default text encoding.

    Returns:
        The decoded JSON content of ``examples.json``.

    Raises:
        FileNotFoundError: If ``examples.json`` does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open("examples.json", "r", encoding="utf-8") as f:
        return json.load(f)
# Load the example inputs once at import time; they are handed to
# gr.Interface(examples=...) when the interface is built below.
examples = load_examples()
def process_fields(fields):
    """Normalise a fields payload into a list of dictionaries.

    Args:
        fields: A JSON string, a single dict, or an iterable whose items are
            dicts or JSON strings.

    Returns:
        A list in which every item that was not already a dict has been
        decoded with ``json.loads``.
    """
    # A string is treated as JSON and decoded first.
    decoded = json.loads(fields) if isinstance(fields, str) else fields

    # A lone field definition is wrapped so the loop below sees a list.
    if isinstance(decoded, dict):
        decoded = [decoded]

    normalised = []
    for item in decoded:
        if not isinstance(item, dict):
            item = json.loads(item)
        normalised.append(item)
    return normalised
def process_records_gradio(records, fields, question, example_records=None):
    """Run the labelling task on JSON-encoded records and return suggestions.

    Args:
        records: JSON string with a list of record objects. A single record
            object is also accepted and treated as a one-item list.
        fields: Fields payload, normalised with ``process_fields``; may be
            empty.
        question: JSON string describing the question; may be empty.
        example_records: Optional JSON string with example records; when
            non-empty it is passed to the task as the ``example_records``
            runtime parameter.

    Returns:
        A JSON string of the form ``{"results": [...]}`` holding the truthy
        ``suggestions`` value of every produced entry, or a plain error
        message string when neither ``fields`` nor ``question`` is provided.

    Raises:
        Exception: With an ``"Error: "``-prefixed message when decoding the
            inputs or running the task fails; the original exception is
            chained as the cause.
    """
    try:
        # Convert string inputs to Python objects.
        records = json.loads(records)
        if isinstance(records, dict):
            # A lone record would otherwise be iterated key by key.
            records = [records]
        example_records = json.loads(example_records) if example_records else None
        fields = process_fields(fields) if fields else None
        question = json.loads(question) if question else None

        if not fields and not question:
            return "Error: Either fields or question must be provided"

        runtime_parameters = {"fields": fields, "question": question}
        if example_records:
            runtime_parameters["example_records"] = example_records
        task.set_runtime_parameters(runtime_parameters)

        inputs = [{"record": record} for record in records]
        # Do not start the task at all when there is nothing to label.
        batches = task.process(inputs=inputs) if inputs else ()

        results = []
        # Consume every batch and every entry in it, rather than assuming
        # exactly one batch per record: a single batch covering all records
        # would otherwise exhaust the generator after the first record.
        for batch in batches:
            for entry in batch:
                # NOTE(review): "suggestions" must match the task's output
                # column name - confirm against ArgillaLabeller.outputs.
                if entry["suggestions"]:
                    results.append(entry["suggestions"])

        return json.dumps({"results": results}, indent=2)
    except Exception as e:
        raise Exception(f"Error: {str(e)}") from e
# Markdown text passed to gr.Interface(description=...); it embeds a
# client-side example of calling this app with a JSON payload.
description = """
An example workflow for JSON payload.

```python
import json
import os
from gradio_client import Client
import argilla as rg

# Initialize Argilla client
client = rg.Argilla(
    api_key=os.environ["ARGILLA_API_KEY"], api_url=os.environ["ARGILLA_API_URL"]
)

# Load the dataset
dataset = client.datasets(name="my_dataset", workspace="my_workspace")

# Prepare example data
example_field = dataset.settings.fields["my_input_field"].serialize()
example_question = dataset.settings.questions["my_question_to_predict"].serialize()

payload = {
    "records": [next(dataset.records()).to_dict()],
    "example_records": [],  # optional example records; may be left empty
    "fields": [example_field],
    "question": example_question,
}

# Use gradio client to process the data
client = Client("davidberenstein1957/distilabel-argilla-labeller")

result = client.predict(
    records=json.dumps(payload["records"]),
    example_records=json.dumps(payload["example_records"]),
    fields=json.dumps(payload["fields"]),
    question=json.dumps(payload["question"]),
    api_name="/predict"
)
```
"""
# Gradio passes the input component values to ``fn`` by position, so the
# components are listed in the order of the handler's signature:
# process_records_gradio(records, fields, question, example_records).
# Each row of ``examples`` is interpreted in this same order.
interface = gr.Interface(
    fn=process_records_gradio,
    inputs=[
        gr.Code(label="Records (JSON)", language="json", lines=5),
        gr.Code(label="Fields (JSON, optional)", language="json"),
        gr.Code(label="Question (JSON, optional)", language="json"),
        gr.Code(label="Example Records (JSON, optional)", language="json", lines=5),
    ],
    examples=examples,
    outputs=gr.Code(label="Suggestions", language="json", lines=10),
    title="Distilabel - ArgillaLabeller - Record Processing Interface",
    description=description,
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    interface.launch()
|