davidberenstein1957 HF staff commited on
Commit
6199610
1 Parent(s): c9bd449

fix: working example with argillalabeller updates

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -7,7 +7,7 @@ from distilabel.steps.tasks.argillalabeller import ArgillaLabeller
7
  llm = InferenceEndpointsLLM(
8
  model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
9
  tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
10
- generation_kwargs={"max_new_tokens": 1000 * 128},
11
  )
12
  task = ArgillaLabeller(llm=llm)
13
  task.load()
@@ -30,7 +30,7 @@ def process_fields(fields):
30
  return [field if isinstance(field, dict) else json.loads(field) for field in fields]
31
 
32
 
33
- def process_records_gradio(records, example_records, fields, question):
34
  try:
35
  # Convert string inputs to dictionaries
36
  records = json.loads(records)
@@ -48,16 +48,15 @@ def process_records_gradio(records, example_records, fields, question):
48
  task.set_runtime_parameters(runtime_parameters)
49
 
50
  results = []
51
- output = task.process(inputs=[{"records": record} for record in records])
52
  for _ in range(len(records)):
53
  entry = next(output)[0]
54
  if entry["suggestions"]:
55
- results.append(entry["suggestions"].serialize())
56
 
57
  return json.dumps({"results": results}, indent=2)
58
  except Exception as e:
59
  raise Exception(f"Error: {str(e)}")
60
- return f"Error: {str(e)}"
61
 
62
 
63
  description = """
 
7
  llm = InferenceEndpointsLLM(
8
  model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
9
  tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
10
+ generation_kwargs={"max_new_tokens": 1000 * 4},
11
  )
12
  task = ArgillaLabeller(llm=llm)
13
  task.load()
 
30
  return [field if isinstance(field, dict) else json.loads(field) for field in fields]
31
 
32
 
33
+ def process_records_gradio(records, fields, question, example_records=None):
34
  try:
35
  # Convert string inputs to dictionaries
36
  records = json.loads(records)
 
48
  task.set_runtime_parameters(runtime_parameters)
49
 
50
  results = []
51
+ output = task.process(inputs=[{"record": record} for record in records])
52
  for _ in range(len(records)):
53
  entry = next(output)[0]
54
  if entry["suggestions"]:
55
+ results.append(entry["suggestions"])
56
 
57
  return json.dumps({"results": results}, indent=2)
58
  except Exception as e:
59
  raise Exception(f"Error: {str(e)}")
 
60
 
61
 
62
  description = """