Spaces:
Build error
Build error
File size: 4,177 Bytes
36821d3 02f562e 36821d3 e9178fb d2e7f91 61f9c80 78f9744 dff7018 e9178fb 61f9c80 e9178fb 61f9c80 e9178fb 62ab45c dff7018 36821d3 dff7018 aaaeb76 dff7018 36821d3 dff7018 36821d3 dff7018 36821d3 dff7018 78f9744 dff7018 36821d3 78f9744 36821d3 aaaeb76 21cc5dc 778c655 21cc5dc 36821d3 dff7018 36821d3 dff7018 36821d3 21cc5dc 36821d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import json
import os
import gradio as gr
from distilabel.llms import LlamaCppLLM
from distilabel.steps.tasks.argillalabeller import ArgillaLabeller
# Location of the GGUF weights on disk, next to this script.
file_path = os.path.join(os.path.dirname(__file__), "Qwen2-5-0.5B-Instruct-f16.gguf")
# NOTE(review): the local name says "f16" while the URL fetches the Q8_0
# quantisation -- confirm which is intended. The file name is left unchanged
# so an already-downloaded copy is still found.
download_url = "https://huggingface.co/gaianet/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/Qwen2.5-0.5B-Instruct-Q8_0.gguf?download=true"

# Fetch the weights once; later runs reuse the file on disk.
if not os.path.exists(file_path):
    import requests
    import tqdm

    chunk_size = 1024 * 1024  # stream the body in 1 MiB pieces
    # Download under a temporary name and rename only on success, so an
    # interrupted transfer never leaves a truncated file that the existence
    # check above would accept on the next start.
    partial_path = file_path + ".part"
    with requests.get(download_url, stream=True, timeout=60) as response:
        # Fail loudly on 4xx/5xx instead of saving an error page as the model.
        response.raise_for_status()
        # The header is optional; without it the bar simply has no total.
        content_length = response.headers.get("content-length")
        total_bytes = int(content_length) if content_length else None
        with open(partial_path, "wb") as f, tqdm.tqdm(
            total=total_bytes,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as progress:
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)
                # Count bytes actually written so the displayed units are right.
                progress.update(len(chunk))
    os.replace(partial_path, file_path)
# Generation settings handed to the LLM: a budget of 1024 * 128 new tokens.
_generation_kwargs = {"max_new_tokens": 1024 * 128}

# Wrap the local GGUF file in a llama.cpp-backed LLM.
llm = LlamaCppLLM(
    model_path=file_path,
    n_gpu_layers=-1,
    # n_ctx=1024 * 128,
    generation_kwargs=_generation_kwargs,
)

# Module-level labelling task shared by every request; loaded once at import.
task = ArgillaLabeller(llm=llm)
task.load()
def load_examples():
    """Load the example inputs for the Gradio interface.

    Reads ``examples.json`` from the current working directory.

    Returns:
        The decoded JSON content of the file.

    Raises:
        FileNotFoundError: If ``examples.json`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; being explicit keeps decoding
    # independent of the platform's locale encoding.
    with open("examples.json", "r", encoding="utf-8") as f:
        return json.load(f)
# Create Gradio examples: read once at import time and passed to
# gr.Interface(examples=...) further down.
examples = load_examples()
def process_fields(fields):
    """Normalise ``fields`` into a list of field dictionaries.

    Args:
        fields: A JSON string, a single dict, or an iterable whose items are
            dicts or JSON strings.

    Returns:
        A list with one entry per field; dict items are passed through
        unchanged and every other item is decoded with ``json.loads``.
    """
    # Step 1: a string is the JSON encoding of one of the other accepted forms.
    decoded = json.loads(fields) if isinstance(fields, str) else fields

    # Step 2: a lone field definition becomes a one-element list.
    candidates = [decoded] if isinstance(decoded, dict) else decoded

    # Step 3: decode any item that is still JSON text.
    normalised = []
    for item in candidates:
        if isinstance(item, dict):
            normalised.append(item)
        else:
            normalised.append(json.loads(item))
    return normalised
def process_records_gradio(records, example_records, fields, question):
    """Generate suggestions for the given records; Gradio callback.

    Args:
        records: JSON string holding a list of records, or a single record
            object (which is treated as a one-element list).
        example_records: JSON string of example records, or an empty value
            to omit them.
        fields: Field definitions in any form accepted by ``process_fields``,
            or an empty value.
        question: JSON string of the question definition, or an empty value.

    Returns:
        A JSON string ``{"results": [...]}`` containing the serialized
        suggestion of every output entry that has one, or a string starting
        with ``"Error: "`` when validation or processing fails.
    """
    try:
        # Convert string inputs to dictionaries
        records = json.loads(records)
        if isinstance(records, dict):
            # A lone record would otherwise be iterated key by key below.
            records = [records]
        example_records = json.loads(example_records) if example_records else None
        fields = process_fields(fields) if fields else None
        question = json.loads(question) if question else None

        if not fields and not question:
            return "Error: Either fields or question must be provided"

        runtime_parameters = {"fields": fields, "question": question}
        if example_records:
            runtime_parameters["example_records"] = example_records

        # NOTE: this reconfigures the module-level ``task`` shared by all calls.
        task.set_runtime_parameters(runtime_parameters)

        output = task.process(inputs=[{"records": record} for record in records])
        # Walk every batch and every entry the task yields instead of assuming
        # one single-entry batch per record: pulling a fixed number of batches
        # with next() raises StopIteration (shown as a blank "Error: ") when
        # fewer are produced, and reading only [0] drops the rest of a batch.
        results = [
            entry["suggestions"].serialize()
            for batch in output
            for entry in batch
            if entry["suggestions"]
        ]
        return json.dumps({"results": results}, indent=2)
    except Exception as e:
        # UI boundary: report the failure as text in the output box.
        return f"Error: {str(e)}"
# Markdown shown in the interface: a client-side example of calling this app.
# The payload defines every key the predict() call reads, so the snippet runs
# as written.
description = """
An example workflow for JSON payload.
```python
import json
import os
from gradio_client import Client
import argilla as rg
# Initialize Argilla client
client = rg.Argilla(
    api_key=os.environ["ARGILLA_API_KEY"], api_url=os.environ["ARGILLA_API_URL"]
)
# Load the dataset
dataset = client.datasets(name="my_dataset", workspace="my_workspace")
# Prepare example data
example_field = dataset.settings.fields["my_input_field"].serialize()
example_question = dataset.settings.questions["my_question_to_predict"].serialize()
payload = {
    "records": [next(dataset.records()).to_dict()],
    "example_records": [],  # optional, leave empty to send none
    "fields": [example_field],
    "question": example_question,
}
# Use gradio client to process the data
client = Client("davidberenstein1957/distilabel-argilla-labeller")
result = client.predict(
    records=json.dumps(payload["records"]),
    example_records=json.dumps(payload["example_records"]),
    fields=json.dumps(payload["fields"]),
    question=json.dumps(payload["question"]),
    api_name="/predict"
)
```
"""
# Input editors, listed in the positional order of
# process_records_gradio(records, example_records, fields, question).
_input_components = [
    gr.Code(label="Records (JSON)", language="json", lines=5),
    gr.Code(label="Example Records (JSON, optional)", language="json", lines=5),
    gr.Code(label="Fields (JSON, optional)", language="json"),
    gr.Code(label="Question (JSON, optional)", language="json"),
]

# Single output box for the JSON (or error text) returned by the callback.
_output_component = gr.Code(label="Suggestions", language="json", lines=10)

interface = gr.Interface(
    fn=process_records_gradio,
    inputs=_input_components,
    outputs=_output_component,
    examples=examples,
    title="Distilabel - ArgillaLabeller - Record Processing Interface",
    description=description,
)

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    interface.launch()
|